Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
Pipeline #492 failed with stages in 0 seconds
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "activations.h"
#include "core/framework/op_kernel.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
domain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.MayInplace(0, 0), \
x<T>);
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \
UNARY_ACTIVATION_COMPUTE(name, T)
#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)
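// For reference (illustration only, not generated code): UNARY_ACTIVATION_OP_TYPED(Gelu, 1, kMSDomain, float)
// expands roughly to
//
//   ONNX_OPERATOR_TYPED_KERNEL_EX(
//       Gelu, kMSDomain, 1, float, kRocmExecutionProvider,
//       (*KernelDefBuilder::Create())
//           .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
//           .MayInplace(0, 0),
//       Gelu<float>);
//
//   template <>
//   Status Gelu<float>::ComputeInternal(OpKernelContext* context) const {
//     UnaryElementwisePreparation p;
//     ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
//     CtxGelu func_ctx = MakeFuncCtx();
//     Impl_Gelu<ToHipType<float>::MappedType>(Stream(), /* input, output casts elided */ ..., &func_ctx, p.output_tensor->Shape().Size());
//     return Status::OK();
//   }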
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/unary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/activation/activations.h"
#include "activations_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
class Affine final : public UnaryElementwise {
public:
Affine(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class ParametricSoftplus final : public UnaryElementwise {
public:
ParametricSoftplus(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class ScaledTanh final : public UnaryElementwise {
public:
ScaledTanh(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class Gelu final : public UnaryElementwise {
public:
Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class QuickGelu final : public UnaryElementwise {
public:
QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct OP_Affine : public CtxAffine {
__device__ __inline__ T operator()(const T& a) const {
return a * (T)alpha + (T)beta;
}
};
template <typename T>
struct OP_ParametricSoftplus : public CtxParametricSoftplus {
__device__ __inline__ T operator()(const T& a) const {
if (a > (T)0)
return (T)alpha * (a * (T)beta + _Log(_Exp(-a * (T)beta) + (T)1));
else
return (T)alpha * _Log(_Exp(a * (T)beta) + (T)1);
}
};
template <typename T>
struct OP_ScaledTanh : public CtxScaledTanh {
__device__ __inline__ T operator()(const T& a) const {
return (T)alpha * _Tanh(a * (T)beta);
}
};
template <typename T>
struct OP_Gelu : public CtxGelu {
__device__ __inline__ T operator()(const T& a) const {
return _Gelu(a);
}
};
template <>
struct OP_Gelu<half> : public CtxGelu {
__device__ __inline__ half operator()(const half& a) const {
return static_cast<half>(_Gelu(static_cast<float>(a)));
}
};
template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
T v = a * static_cast<T>(alpha);
T one = static_cast<T>(1.f);
T zero = static_cast<T>(0.f);
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
return a * sigmoid;
}
};
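// For reference, the closed forms implemented by the functors above (alpha/beta come from the Ctx structs):
//   Affine(x)             = alpha * x + beta
//   ParametricSoftplus(x) = alpha * log(1 + exp(beta * x)), split into two branches to avoid exp overflow for x > 0
//   ScaledTanh(x)         = alpha * tanh(beta * x)
//   QuickGelu(x)          = x * sigmoid(alpha * x), with a sign-split sigmoid for numerical stability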
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double)
#define UNARY_ACTIVATION_OP_NAME(name) \
UNARY_ACTIVATION_IMPL(name); \
SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/activation/activations_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
typedef onnxruntime::rocm::CtxAlphaBeta CtxAffine;
typedef onnxruntime::rocm::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::rocm::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::rocm::CtxNull CtxGelu;
typedef onnxruntime::rocm::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "contrib_ops/cpu/aten_ops/aten_op.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
ATen, kPytorchAtenDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()),
onnxruntime::contrib::ATen);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
namespace onnxruntime {
namespace rocm {
struct __align__(8) Half4 {
half2 x;
half2 y;
};
__device__ __forceinline__ Half4 operator+(const Half4& a, const Half4& b) {
Half4 r;
r.x = a.x + b.x;
r.y = a.y + b.y;
return r;
}
__device__ __forceinline__ float2 operator+(const float2& a, const float2& b) {
return make_float2(a.x + b.x, a.y + b.y);
}
__device__ __forceinline__ float4 operator+(const float4& a, const float4& b) {
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
} // namespace rocm
} // namespace onnxruntime
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void AddBiasTransposeTrt(const T* input, const T* biases, T* output) {
// Input: BxSxMxNxH (Format 2)
// Output: BxSxNxMxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int H = blockDim.x;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const int NH = N * H;
const int offset = (b * S + s) * M * NH;
const int in_offset = offset + m * NH + n * H;
const int out_offset = offset + (n * M + m) * H;
const int h = threadIdx.x;
if (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
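// In index form: element (b, s, m, n, h) moves from
//   in_offset  = ((b * S + s) * M + m) * N * H + n * H + h
// to
//   out_offset = ((b * S + s) * N + n) * M * H + m * H + h
// i.e. the M (matrix) and N (head) axes are swapped while B, S and H stay in place.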
template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int H = head_size;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const int NH = N * H;
const int offset = (b * S + s) * M * NH;
const int in_offset = offset + m * NH + n * H;
const int out_offset = offset + (n * M + m) * H;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTransposeTrt(const T* query, const T* key, const T* value, const T* biases, T* output) {
// Q: BxSxNxH
// K: BxSxNxH
// V: BxSxNxH
// Output: BxSxNxMxH
// B is batch_size, S is sequence_length, M is number of matrices (3), N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int H = blockDim.x;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const T* input = (m == 0 ? query : (m == 1 ? key : value));
const int NH = N * H;
const int in_offset = (b * S + s) * NH + n * H;
const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;
const int h = threadIdx.x;
if (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size,
const T* query, const T* key, const T* value, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int stride = blockDim.x;
const int H = head_size;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const T* input = (m == 0 ? query : (m == 1 ? key : value));
const int NH = N * H;
const int in_offset = (b * S + s) * NH + n * H;
const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;
int h = threadIdx.x;
// The Large variant is used when head_size exceeds blockDim.x, so loop with stride like the other *Large kernels.
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output) {
// Input: BxSxMxNxH (Format 1)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int M = gridDim.z;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
const int out_offset = s * head_size + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
// Input: BxSxMxNxH (Format 1)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y; // head_num_id
int s = blockIdx.x; // sequence_id
int b = blockIdx.y; // batch_id
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
const int h = threadIdx.x; // head_element_id
const int qk_head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int head_size = (m == 2 ? v_head_size : qk_head_size);
const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);
int in_offset;
int out_offset;
int bias_offset;
in_offset = b * (total_head_size * sequence_length) + // B
s * (total_head_size) + // S
m * (qk_head_size * num_heads) + // M
n * head_size + // N
h; // H
out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
b * (num_heads * head_size * sequence_length) + // B
n * (sequence_length * head_size) + // N
s * (head_size) + // S
h; // H
bias_offset = m * (num_heads * qk_head_size) + // M
n * (head_size) + // N
h; // H
if (h < head_size) {
output[out_offset] = input[in_offset] + biases[bias_offset];
}
}
template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int M = gridDim.z;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
// Input: MxBxSxNxH (Format 0)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
const int out_offset = s * H + n * sequence_length * H + (b + m * batch_size) * NHS;
const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
const int out_offset = (s + n * sequence_length) * H + (b + m * batch_size) * NHS;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
void InvokeAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, const int v_head_size) {
const dim3 grid(sequence_length, batch_size, num_matrices);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
if (format == 2) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, input, biases, output);
} else if (format == 1) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output, v_head_size);
}
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTranspose<T>), grid, block, 0, stream, input, biases, output);
}
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
if (format == 2) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
} else if (format == 1) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKVLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
} else {
ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block");
}
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
}
}
}
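// Dispatch note: the plain kernels above use one thread per head element, so they require
// qk_head_size * num_heads <= max_threads_per_block; otherwise the *Large variants are launched,
// where each thread strides over the head dimension with step blockDim.x.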
template <>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output,
bool enable_half4, const int v_head_size) {
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
const int H = qk_head_size / 4;
const int H_v = v_head_size / 4;
const Half4* input2 = reinterpret_cast<const Half4*>(input);
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
Half4* output2 = reinterpret_cast<Half4*>(output);
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
const int H = qk_head_size / 2;
const int H_v = v_head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
const half2* biases2 = reinterpret_cast<const half2*>(biases);
half2* output2 = reinterpret_cast<half2*>(output);
InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
} else {
InvokeAddBiasTranspose<half>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}
template <>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output,
bool /*enable_half4*/, const int v_head_size) {
if (0 == (qk_head_size % 4)) {
const int H = qk_head_size / 4;
const float4* input2 = reinterpret_cast<const float4*>(input);
const float4* biases2 = reinterpret_cast<const float4*>(biases);
float4* output2 = reinterpret_cast<float4*>(output);
InvokeAddBiasTranspose<float4>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
} else if (0 == (qk_head_size & 1)) {
const int H = qk_head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
const float2* biases2 = reinterpret_cast<const float2*>(biases);
float2* output2 = reinterpret_cast<float2*>(output);
InvokeAddBiasTranspose<float2>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
} else {
InvokeAddBiasTranspose<float>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}
template <typename T>
void InvokeAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* biases, const T* query, const T* key, const T* value, T* output) {
constexpr int num_matrices = 3;
const dim3 grid(sequence_length, batch_size, num_matrices);
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, query, key, value, biases, output);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, head_size, query, key, value, biases, output);
}
}
template <>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const float* biases, const float* query, const float* key, const float* value, float* output) {
ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
}
template <>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const half* biases, const half* query, const half* key, const half* value, half* output) {
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const Half4* query2 = reinterpret_cast<const Half4*>(query);
const Half4* key2 = reinterpret_cast<const Half4*>(key);
const Half4* value2 = reinterpret_cast<const Half4*>(value);
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
Half4* output2 = reinterpret_cast<Half4*>(output);
InvokeAddBiasTransposeTrt<Half4>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, H,
biases2, query2, key2, value2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* query2 = reinterpret_cast<const half2*>(query);
const half2* key2 = reinterpret_cast<const half2*>(key);
const half2* value2 = reinterpret_cast<const half2*>(value);
const half2* biases2 = reinterpret_cast<const half2*>(biases);
half2* output2 = reinterpret_cast<half2*>(output);
InvokeAddBiasTransposeTrt<half2>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, H,
biases2, query2, key2, value2, output2);
} else {
InvokeAddBiasTransposeTrt<half>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, head_size,
biases, query, key, value, output);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Fused kernel of Add (bias) and Transpose.
// Shape of inputs and outputs:
// biases: (num_matrices, num_heads * head_size)
// format 0:
// input: (num_matrices, batch_size, sequence_length, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 1:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 2:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
template <typename T>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);
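// Illustrative call (hypothetical variable names, not part of the original header): transposing the
// fused QKV GEMM output of an attention layer from BxSx3xNxH (format 1) to 3xBxNxSxH might look like
//
//   LaunchAddBiasTranspose<half>(stream, /*num_matrices=*/3, /*format=*/1, prop.maxThreadsPerBlock,
//                                batch_size, sequence_length, num_heads, head_size,
//                                qkv_gemm_output, qkv_bias, transposed_qkv,
//                                /*enable_half4=*/true, /*v_head_size=*/head_size);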
// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// It assumes sequence_length == kv_sequence_length and head_size == v_head_size.
template <typename T>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const T* biases, const T* query, const T* key, const T* value, T* output);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
const T* tensor_in,
const T* tensor_add,
T* tensor_out) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int chunk_id = blockIdx.z;
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int H = blockDim.x;
// K: number of identical tensors
// tensor_in: K x BxNxPxH
// tensor_add: K x BxNxLxH
// tensor_out: K x BxNxTxH, where T = P + L
const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
if (s < tensor_in_sequence_length) {
const int past_SH = tensor_in_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
tensor_out[out_offset] = tensor_in[in_offset];
} else if (s < all_sequence_length) {
const int SH = tensor_add_sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
tensor_out[out_offset] = tensor_add[in_offset];
}
}
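// Example: with past length P = 6 and new length L = 2 (so all_sequence_length T = 8), output rows
// s = 0..5 copy from tensor_in (the past K/V) and rows s = 6..7 copy from tensor_add (the newly
// computed K/V), independently for each batch, head and chunk.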
template <typename T>
__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
const int H,
const T* tensor_in,
const T* tensor_add,
T* tensor_out) {
// Used when H * num_heads > max_threads_per_block.
int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int chunk_id = blockIdx.z;
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int stride = blockDim.x;
// K: number of identical tensors
// tensor_in: K x BxNxPxH
// tensor_add: K x BxNxLxH
// tensor_out: K x BxNxTxH
const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
while (h < H) {
int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
if (s < tensor_in_sequence_length) {
const int past_SH = tensor_in_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
tensor_out[out_offset] = tensor_in[in_offset];
} else if (s < all_sequence_length) {
const int SH = tensor_add_sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
tensor_out[out_offset] = tensor_add[in_offset];
}
h += stride;
}
}
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float>), grid, block, 0, stream, sequence_length, tensor_in, tensor_add, tensor_out);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float>), grid, block, 0, stream, sequence_length,
head_size,
tensor_in,
tensor_add,
tensor_out);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const half2*>(tensor_in),
reinterpret_cast<const half2*>(tensor_add),
reinterpret_cast<half2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const half2*>(tensor_in),
reinterpret_cast<const half2*>(tensor_add),
reinterpret_cast<half2*>(tensor_out));
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half>), grid, block, 0, stream, sequence_length, tensor_in, tensor_add, tensor_out);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half>), grid, block, 0, stream, sequence_length,
head_size,
tensor_in,
tensor_add,
tensor_out);
}
}
return HIP_CALL(hipGetLastError());
}
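// Past-to-present concatenation: matrix_num is fixed to 2 because the past/present tensors pack the
// key and value caches as two chunks (2 x BxNxSxH), so a single launch concatenates past K with new K
// and past V with new V.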
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
sequence_length,
batch_size,
head_size,
num_heads,
max_threads_per_block,
2,
past,
k_v,
present);
}
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
sequence_length,
batch_size,
head_size,
num_heads,
max_threads_per_block,
2,
past,
k_v,
present);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
size_t GetAttentionScratchSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t sequence_length,
size_t all_sequence_length);
size_t GetAttentionWorkspaceSize(
size_t element_size,
size_t batchsize,
size_t num_heads,
size_t qk_head_size,
size_t v_head_size,
size_t sequence_length,
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner);
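// Rough field guide for AttentionData (based on how these buffers are used by the attention kernels):
// gemm_buffer holds the packed QKV GEMM output when Q/K/V come from a single fused projection,
// otherwise query/key/value point at the separate projections; mask_index/mask_index_dims describe the
// attention mask; past and present are the cached and updated key/value state; extra_add_qk is an
// optional additive bias on the attention scores; workspace and output are the scratch and result buffers.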
template <typename T>
struct AttentionData {
const T* gemm_buffer;
const T* bias;
const T* query;
const T* key;
const T* value;
const int* mask_index;
gsl::span<const int64_t> mask_index_dims;
const T* past;
const T* extra_add_qk;
T* workspace;
T* output;
T* present;
};
template <typename T>
Status QkvToContext(
const hipDeviceProp_t& prop,
rocblas_handle& rocblas,
hipStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
void* fused_runner);
Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
hipStream_t stream, // HIP stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);
// BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out);
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Modifications: add transpose kernels for TRT format
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void TransposeCtx(const int H, const bool reversed_bs, const T* input, T* output) {
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
int out_offset = 0;
if (reversed_bs) {
const int batch_size = gridDim.y;
const int BNH = NH * batch_size;
out_offset = n * H + b * NH + s * BNH;
} else {
out_offset = n * H + s * NH + b * NHS;
}
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
template <typename T>
__global__ void TransposeCtxLarge(const int H, const bool reversed_bs, const T* input, T* output) {
// Used when H * num_heads > max_threads_per_block.
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int stride = blockDim.x;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
int out_offset = 0;
if (reversed_bs) {
const int batch_size = gridDim.y;
const int BNH = NH * batch_size;
out_offset = n * H + b * NH + s * BNH;
} else {
out_offset = n * H + s * NH + b * NHS;
}
int i = threadIdx.x;
while (i < H) {
output[out_offset + i] = input[in_offset + i];
i += stride;
}
}
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
template <typename T>
__global__ void TransposeQKV(const int H, const bool reversed_bs, const T* input, T* output) {
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrices
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int chunk_num = gridDim.z;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = 0;
if (reversed_bs) {
const int BNH = NH * batch_size;
in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
} else {
in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
}
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
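// In index form (reversed_bs == false): element (b, s, m, n, h) of a BxSxKxNxH input is written to
// position (m, b, n, s, h) of the KxBxNxSxH output; when reversed_bs is true the same output layout is
// produced from an SxBxKxNxH input.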
template <typename T>
__global__ void TransposeQKVLarge(const int H, const bool reversed_bs, const T* input, T* output) {
// Used when H * num_heads > max_threads_per_block.
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrices
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int chunk_num = gridDim.z;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = 0;
if (reversed_bs) {
const int BNH = NH * batch_size;
in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
} else {
in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
}
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
int i = threadIdx.x;
while (i < H) {
output[out_offset + i] = input[in_offset + i];
i += stride;
}
}
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, matrix_num);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, matrix_num);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TrtSequenceOffset kernels are modified from FasterTransformer
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
constexpr int32_t kMAX_THREADS_PER_BLOCK = 256;
// -----------------------------------
// Get indices of non-padding tokens and padding tokens. Here we assume that padding is on the right side of sequence.
// sequence_token_count is number of non-padding tokens per sequence, and it has shape [batch_size].
// For example, we have 3 sequences with 1, 2, 4 non-padding tokens and positions like the following (* means padding):
// Sequence_0: 0, 1*, 2*, 3*
// Sequence_1: 4, 5, 6*, 7*
// Sequence_2: 8, 9, 10, 11
// token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7*
// token_count_buffer has two numbers for non-padding tokens:
// total_token_count: 1 + 2 + 4 = 7
// max_token_count: 4
// cumulated_token_count: 0, 1, 1+2, 1+2+4
__global__ void getTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length) {
// Find offset of non-padding tokens, and max sequence length among all batches
// TODO(tianleiwu): Use hipcub::DevicePartition::Flagged like BuildGlobalIndex in longformer_global_impl.cu
// to build token_offset when sequence length is large.
int total_tokens = 0;
int max_tokens = 0;
int index = 0;
cumulated_token_count[0] = 0;
for (int i = 0; i < batch_size; i++) {
const int count = sequence_token_count[i];
if (count > max_tokens) {
max_tokens = count;
}
cumulated_token_count[i + 1] = cumulated_token_count[i] + count;
for (int j = 0; j < count; j++) {
token_offset[index] = i * sequence_length + j;
index++;
}
total_tokens += count;
}
// Offset of paddings
for (int i = 0; i < batch_size; i++) {
const int count = sequence_token_count[i];
for (int j = 0; j < sequence_length - count; j++) {
token_offset[index] = i * sequence_length + count + j;
index++;
}
}
token_count_buffer[0] = total_tokens;
token_count_buffer[1] = max_tokens;
}
void LaunchGetTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream) {
hipLaunchKernelGGL(getTokenOffset, 1, 1, 0, stream,
token_count_buffer, token_offset, cumulated_token_count, sequence_token_count, batch_size, sequence_length);
}
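// Sizing note for the launcher above: token_offset must hold batch_size * sequence_length ints,
// cumulated_token_count must hold batch_size + 1 ints, and token_count_buffer must hold 2 ints
// (total_token_count at index 0, max_token_count at index 1). The kernel runs on a single thread,
// so it is only intended for modest batch sizes (see the TODO above).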
// -----------------------------------
// Remove paddings
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
removePadding(T* target, const T* source, const int* token_offset, const int width) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int source_offset = token_offset[token_index];
const int target_offset = token_index;
for (int i = tid; i < width; i += blockDim.x) {
target[target_offset * width + i] = source[source_offset * width + i];
}
}
template <>
void LaunchRemovePadding(
half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream) {
// input: [batch_size, sequence_length, hidden_size]
// output: [token_count, hidden_size]
// Make sure memory is aligned to 128 bit
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
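// Vector-width selection: with half elements, hidden_size % 8 == 0 allows 16-byte (int4) copies,
// % 4 allows 8-byte (int64_t) copies, % 2 allows 4-byte (int32_t) copies, and the fallback copies one
// half (int16_t) at a time. The reinterpret_casts below rely on the 128-bit alignment enforced above.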
if (hidden_size % 8 == 0) {
const int width = hidden_size / 8;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else {
const int width = hidden_size;
const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
int16_t* output2 = reinterpret_cast<int16_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int16_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
}
}
template <>
void LaunchRemovePadding(
float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream) {
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
} else {
const int width = hidden_size;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
}
}
// -----------------------------------
// Recover padding.
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
restorePadding(T* target, const T* source, const int* token_offset, const int width, const int token_count) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int target_seq_id = token_offset[token_index];
const int source_seq_id = token_index;
constexpr T padding_zero = 0;
if (token_index < token_count) {
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = source[source_seq_id * width + i];
}
} else {
// It is padding: fill with zeros
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = padding_zero;
}
}
}
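// int4 has no implicit conversion from 0, so the generic `constexpr T padding_zero = 0;` above would
// not compile for it; the specialization below zero-fills with an aggregate initializer instead.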
template <>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
restorePadding(int4* target, const int4* source, const int* token_offset, const int width, const int token_count) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int target_seq_id = token_offset[token_index];
const int source_seq_id = token_index;
int4 padding_zero{0, 0, 0, 0};
if (token_index < token_count) {
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = source[source_seq_id * width + i];
}
} else {
// It is padding: fill with zeros
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = padding_zero;
}
}
}
template <>
void LaunchRestorePadding(
float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream) {
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
int grid_size = batch_size * sequence_length;
if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else {
const int width = hidden_size;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
}
}
template <>
void LaunchRestorePadding(
half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream) {
// input: [token_count, hidden_size]
// output: [batch_size, sequence_length, hidden_size]
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
int grid_size = batch_size * sequence_length;
if (hidden_size % 8 == 0) {
const int width = hidden_size / 8;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else {
const int width = hidden_size;
const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
int16_t* output2 = reinterpret_cast<int16_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int16_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
}
}
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size) {
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
for (int i = 0; i < batch_size; i++) {
tmp_offset[i + 1] = tmp_offset[i] + sequence_token_count[i];
}
}
__syncthreads();
for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
// Get sequence offset for TensorRT fused attention when there is no padding (or padding is removed)
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
hipStream_t stream) {
hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (batch_size + 1), stream,
trt_mha_padding_offset, sequence_token_count, batch_size);
}
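// Illustration (hypothetical values): with batch_size = 2 and sequence_token_count = {3, 5},
// the launcher above writes the exclusive prefix sum {0, 3, 8} into trt_mha_padding_offset.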
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
const int sequence_length) {
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
// B for fused attention is 2 * batch_size
for (int i = 0; i < batch_size; i++) {
tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_token_count[i];
tmp_offset[i * 2 + 2] = sequence_length * (i + 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < 2 * batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
// Get sequence offset for TensorRT fused attention when we keep the padding
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream) {
hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (2 * batch_size + 1), stream,
trt_mha_padding_offset, sequence_token_count, batch_size, sequence_length);
}
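// Illustration (hypothetical values): with batch_size = 2, sequence_length = 8 and
// sequence_token_count = {3, 5}, the padded variant writes {0, 3, 8, 13, 16}: each pair
// (tmp_offset[2i], tmp_offset[2i + 1]) brackets the valid tokens of batch i, and
// tmp_offset[2i + 2] marks the end of that padded sequence.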
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Build token indices for non-padding tokens and padding tokens.
void LaunchGetTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream);
// Remove paddings from input.
template <typename T>
void LaunchRemovePadding(
T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream);
// Rebuild paddings to restore output shape.
template <typename T>
void LaunchRestorePadding(
T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream);
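// Typical call sequence (sketch only, inferred from the declarations above):
//   LaunchGetTokenOffset(token_count_buffer, token_offset, cumulated_token_count,
//                        sequence_token_count, batch_size, sequence_length, stream);
//   LaunchRemovePadding(compact, padded_input, token_offset, token_count, hidden_size, stream);
//   ... run attention on the compact [token_count, hidden_size] buffer ...
//   LaunchRestorePadding(padded_output, compact, token_offset, token_count, hidden_size,
//                        batch_size, sequence_length, stream);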
// Padding offset for TensorRT fused attention kernel
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* mask_index,
const int batch_size,
hipStream_t stream);
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* mask_index,
const int batch_size,
const int sequence_length,
hipStream_t stream);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/decoder_attention.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
DecoderAttention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
DecoderAttention<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
namespace {
Status CheckInputs(const TensorShape& query_shape,
const TensorShape& key_shape,
const TensorShape& q_weights_shape,
const TensorShape& kv_weights_shape,
const TensorShape& bias_shape,
const Tensor* key_padding_mask,
const Tensor* key_cache,
const Tensor* value_cache,
const bool static_kv,
const bool use_past,
const bool has_layer_state,
const bool has_key_padding_mask) {
const auto& query_shape_dims = query_shape.GetDims();
if (query_shape_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
query_shape_dims.size());
}
int sequence_length = static_cast<int>(query_shape_dims[0]);
int batch_size = static_cast<int>(query_shape_dims[1]);
int hidden_size = static_cast<int>(query_shape_dims[2]);
const auto& key_shape_dims = key_shape.GetDims();
if (key_shape_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_shape_dims.size());
}
int kv_sequence_length = static_cast<int>(key_shape_dims[0]);
if (query_shape_dims[1] != key_shape_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same batch size");
}
if (query_shape_dims[2] != key_shape_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same hidden size");
}
const auto& q_weights_dims = q_weights_shape.GetDims();
if (q_weights_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'q_weights' is expected to have 2 dimensions, got ",
q_weights_dims.size());
}
const auto& kv_weights_dims = kv_weights_shape.GetDims();
if (kv_weights_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'kv_weights' is expected to have 2 dimensions, got ",
kv_weights_dims.size());
}
if (q_weights_dims[0] != hidden_size || q_weights_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "q_weights shall have shape (hidden size, hidden size)");
}
if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
}
const auto& bias_dims = bias_shape.GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
bias_dims.size());
}
if (bias_dims[0] != 3 * static_cast<int64_t>(hidden_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias shall have shape (3 * hidden size)");
}
int key_length = kv_sequence_length;
if (key_padding_mask != nullptr && has_key_padding_mask == true) {
const auto& kp_mask_dims = key_padding_mask->Shape().GetDims();
if (kp_mask_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key_padding_mask' is expected to have 2 dimension, got ",
kp_mask_dims.size());
}
if (kp_mask_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_padding_mask shall have same batch size with query");
}
if (!has_layer_state || !use_past) {
if (!static_kv) {
key_length = sequence_length;
}
} else {
if (!static_kv) {
key_length = sequence_length + kv_sequence_length;
}
}
if (kp_mask_dims[1] != key_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"key_padding_mask shall have same sequence length as generated key");
}
}
if (key_cache != nullptr && value_cache != nullptr && has_layer_state && use_past) {
const auto& key_cache_dims = key_cache->Shape().GetDims();
if (key_cache_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key_cache' is expected to have 4 dimension, got ",
key_cache_dims.size());
}
const auto& value_cache_dims = value_cache->Shape().GetDims();
if (value_cache_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ",
value_cache_dims.size());
}
if (key_cache_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have same batch size as query");
}
if (value_cache_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have same batch size as query");
}
if (key_cache_dims[1] * key_cache_dims[3] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have correct hidden size");
}
if (value_cache_dims[1] * value_cache_dims[3] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have correct hidden size");
}
}
return Status::OK();
}
} // anonymous namespace
template <typename T>
DecoderAttention<T>::DecoderAttention(const OpKernelInfo& info) : RocmKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
}
template <typename T>
Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query(context->Input<Tensor>(0));
const Tensor* key(context->Input<Tensor>(1));
const Tensor* q_weights(context->Input<Tensor>(2));
const Tensor* kv_weights(context->Input<Tensor>(3));
const Tensor* bias(context->Input<Tensor>(4));
const Tensor* key_padding_mask(context->Input<Tensor>(5));
const Tensor* key_cache(context->Input<Tensor>(6));
const Tensor* value_cache(context->Input<Tensor>(7));
const Tensor* static_kv(context->Input<Tensor>(8));
const Tensor* use_past(context->Input<Tensor>(9));
const Tensor* has_layer_state(context->Input<Tensor>(10));
const Tensor* has_key_padding_mask(context->Input<Tensor>(11));
hipStream_t stream = Stream();
  // Copy static_kv, use_past, has_layer_state and has_key_padding_mask to CPU
auto pinned_buffer = AllocateBufferOnCPUPinned<void>(4 * sizeof(bool));
bool* kernel_state_pinned = reinterpret_cast<bool*>(pinned_buffer.get());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned, static_kv->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 1, use_past->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
// Create an event to make sure the async copy is finished before reading the data.
AutoDestoryCudaEvent new_event;
hipEvent_t& isCopyDone = new_event.Get();
HIP_RETURN_IF_ERROR(hipEventCreate(&isCopyDone));
HIP_RETURN_IF_ERROR(hipEventRecord(isCopyDone, stream));
auto& device_prop = GetDeviceProp();
// query shape (batch_size, sequence_length, input_hidden_size)
const auto& query_shape = query->Shape();
int sequence_length = static_cast<int>(query_shape[0]);
int batch_size = static_cast<int>(query_shape[1]);
int hidden_size = static_cast<int>(query_shape[2]);
const auto& key_shape = key->Shape();
int key_sequence_length = static_cast<int>(key_shape[0]);
int head_size = hidden_size / num_heads_;
  // k, v sequence length after gemm
int kv_sequence_length = 0;
// Generate q, k, v w/o cache
// query input: (S, B, h1)
// key input: (S', B, h1)
// weight: (h1, h2)
// h = N*H
rocblas_handle rocblas = RocblasHandle();
ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));
constexpr size_t element_size = sizeof(T);
typedef typename ToHipType<T>::MappedType HipT;
HipT one = ToHipType<T>::FromFloat(1.0f);
HipT zero = ToHipType<T>::FromFloat(0.0f);
int m = 0, n = 0, k = 0;
IAllocatorUniquePtr<T> gemm_query_buffer_p(nullptr);
IAllocatorUniquePtr<T> gemm_kv_buffer_p(nullptr);
HIP_RETURN_IF_ERROR(hipEventSynchronize(isCopyDone));
bool static_kv_ = *kernel_state_pinned;
bool use_past_ = *(kernel_state_pinned + 1);
bool has_layer_state_ = *(kernel_state_pinned + 2);
bool has_key_padding_mask_ = *(kernel_state_pinned + 3);
ORT_RETURN_IF_ERROR(
CheckInputs(query->Shape(),
key->Shape(),
q_weights->Shape(),
kv_weights->Shape(),
bias->Shape(),
key_padding_mask,
key_cache,
value_cache,
static_kv_,
use_past_,
has_layer_state_,
has_key_padding_mask_)
);
// calculate q
gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
n = hidden_size;
k = hidden_size;
// TODO(tianleiwu): fuse bias and transpose
// broadcast bias for query: (h2, S*B)
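  // The rank-1 GEMM below multiplies bias (h2 x 1) by a row vector of ones (1 x S*B),
  // replicating the bias into every column of gemm_query_buffer; the next GEMM then
  // accumulates q_weights * query on top of it by passing beta = one.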
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>()), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
// matmul: (h2, h1)*(h1, S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(q_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
// gemm_query_buffer in col-base: (h2, S*B)
  // calculate k, v
n = 2 * hidden_size;
k = hidden_size;
if (!has_layer_state_ || !use_past_) {
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
kv_sequence_length = sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size);
m = key_sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
kv_sequence_length = key_sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(key->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
}
} else {
ORT_ENFORCE(nullptr != key_cache && nullptr != value_cache); // (B, N, S, H)
const auto& cache_shape = key_cache->Shape();
// key and value cache have identical shape
int cache_sequence_length = static_cast<int>(cache_shape[2]);
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
kv_sequence_length = cache_sequence_length + sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
kv_sequence_length = cache_sequence_length;
}
}
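  // The qkv scratch below is sized for q ([B, S, h]) plus k and v ([B, kv_sequence_length, h]
  // each); the second buffer is workspace for the attention kernel launched at the end.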
size_t bytes = element_size * batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size;
auto qkv_buffer_p = GetScratchBuffer<void>(bytes);
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (2 * head_size + kv_sequence_length);
auto workspace_p = GetScratchBuffer<void>(bytes);
Tensor* output(context->Output(0, query_shape));
TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size});
Tensor* new_key_cache(context->Output(1, new_cache_shape));
Tensor* new_value_cache(context->Output(2, new_cache_shape));
return LaunchDecoderAttentionKernel(
device_prop,
stream,
rocblas,
element_size,
batch_size,
sequence_length,
kv_sequence_length,
num_heads_,
head_size,
static_kv_,
use_past_,
has_layer_state_,
has_key_padding_mask_,
nullptr == gemm_query_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_query_buffer_p.get()),
nullptr == gemm_kv_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_kv_buffer_p.get()),
nullptr == key_padding_mask ? nullptr : key_padding_mask->Data<bool>(),
nullptr == key_cache ? nullptr : key_cache->Data<T>(),
nullptr == value_cache ? nullptr : value_cache->Data<T>(),
qkv_buffer_p.get(),
workspace_p.get(),
output->MutableData<T>(),
nullptr == new_key_cache ? nullptr : new_key_cache->MutableData<T>(),
nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class DecoderAttention final : public RocmKernel {
public:
DecoderAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
int num_heads_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include <hipcub/hipcub.hpp>
using namespace onnxruntime::rocm;
using namespace hipcub;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__device__ inline T Rsqrt(const T& x);
template <>
__device__ inline float Rsqrt(const float& x) {
return rsqrtf(x);
}
template <>
__device__ inline half Rsqrt(const half& x) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return hrsqrt(x);
#else
return half(rsqrtf(float(x)));
#endif
}
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hadd2(a, b);
#else
return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y));
#endif
}
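// Note: __CUDA_ARCH__ is not defined under HIP device compilation, so the guards above take
// the native hrsqrt/__hadd2 paths; the float fallbacks appear to be carried over from the
// hipified CUDA source for pre-sm_53 targets.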
struct KeyValuePairSum {
__device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a,
const hipcub::KeyValuePair<float, float>& b) {
return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
}
__device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a,
const hipcub::KeyValuePair<half, half>& b) {
const half2 a2 = __halves2half2(a.key, a.value);
const half2 b2 = __halves2half2(b.key, b.value);
const half2 res = AddHalf2(a2, b2);
return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
}
__device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a,
const hipcub::KeyValuePair<half2, half2>& b) {
return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}
};
template <typename T, int TPB>
__device__ inline void LayerNorm(
const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
const T* gamma, const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
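  // Each thread contributes (sum(x) / ld, sum(x * x) / ld), so after the block reduce
  // sum_kv.key = E[x] and sum_kv.value = E[x^2]; the inverse std.dev. is then
  // rsigma = 1 / sqrt(E[x^2] - E[x]^2 + epsilon).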
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
const T b = (nullptr == beta) ? (T)0 : beta[i];
output[idx] = g * (val - mu) * rsigma + b;
}
}
template <typename T, int TPB, int ILP>
__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<T, T>& thread_data,
const int ld, const int idx, const T* beta, const T* gamma,
const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
  // Small setting: the block covers the leading dimension (TPB >= ld) and the input
  // values are already available in registers.
using VecT = aligned_vector<T, ILP>;
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
T beta_v[ILP], gamma_v[ILP], output_v[ILP];
if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
}
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
VecT* output_val = reinterpret_cast<VecT*>(&output_v);
KeyValuePairSum pair_sum;
const hipcub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
if (ILP * threadIdx.x < ld) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = (beta != nullptr)
? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
: gamma_v[i] * (input_v[i] - mu) * rsigma;
}
*(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/rocm/bert/longformer_global_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
LongformerAttention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
LongformerAttention<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info)
: RocmKernel(info), LongformerAttentionBase(info) {
use_compact_memory_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseCompactMemory, true);
use_half4_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseHalf4, true);
}
template <typename T>
Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* attention_mask = context->Input<Tensor>(3);
const Tensor* global_weights = context->Input<Tensor>(4);
const Tensor* global_bias = context->Input<Tensor>(5);
const Tensor* global_attention_mask = context->Input<Tensor>(6);
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), attention_mask->Shape(),
global_weights->Shape(), global_bias->Shape(), global_attention_mask->Shape()));
// Input shapes:
// input : (batch_size, sequence_length, hidden_size)
// weights : (hidden_size, 3 * hidden_size) -- format 1
// (3, hidden_size, hidden_size) -- format 0
// bias : (3 * hidden_size) -- format 1 (bias for Q, K, V)
// (5 * hidden_size) -- format 0 (bias for Q, K, V, Global_K, Global_V)
// attention_mask : (batch_size, sequence_length)
// global_weights : (hidden_size, 3 * hidden_size) -- format 1
// (3, hidden_size, hidden_size) -- format 0
// global_bias : (3 * hidden_size) -- format 1 (bias for Global_Q, Global_K, Global_V)
// (1 * hidden_size) -- format 0 (bias for Global_Q)
// global_attention_mask : (batch_size, sequence_length)
// Output shapes:
// output : (batch_size, sequence_length, hidden_size)
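  // Example (hypothetical values): batch_size = 1, sequence_length = 4096, hidden_size = 768 and
  // num_heads_ = 12 give head_size = 64; with format 1, weights is (768, 2304) and bias has 2304 elements.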
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int hidden_size = static_cast<int>(shape[2]);
int head_size = hidden_size / num_heads_;
Tensor* output = context->Output(0, shape);
rocblas_handle rocblas = RocblasHandle();
hipStream_t stream = Stream();
ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));
constexpr size_t element_size = sizeof(T);
// TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node.
// Build Global Index
auto global_index_buffer = GetScratchBuffer<int>(static_cast<size_t>(batch_size) * sequence_length);
auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);
size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length);
auto global_scratch_buffer = GetScratchBuffer<void>(global_scratch_bytes);
auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(BuildGlobalIndex(
device_prop,
stream,
global_attention_mask->Data<int>(),
batch_size,
sequence_length,
global_index_buffer.get(),
batch_global_num_buffer.get(),
global_scratch_buffer.get(),
global_scratch_bytes));
// Copy batch_global_num to CPU
size_t pinned_buffer_bytes = GetPinnedBufferSize(batch_size);
auto pinned_buffer = AllocateBufferOnCPUPinned<void>(pinned_buffer_bytes);
int* batch_global_num_pinned = reinterpret_cast<int*>(pinned_buffer.get());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(batch_global_num_pinned,
batch_global_num_buffer.get(),
batch_size * sizeof(int),
hipMemcpyDeviceToHost,
stream));
// Create an event to make sure the async copy is finished before reading the data.
AutoDestoryCudaEvent new_event;
hipEvent_t& is_copy_done = new_event.Get();
HIP_RETURN_IF_ERROR(hipEventCreateWithFlags(&is_copy_done, hipEventDisableTiming));
HIP_RETURN_IF_ERROR(hipEventRecord(is_copy_done, stream));
size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size;
// Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v
  // TODO(tianleiwu): compact global_q only needs a buffer of batch_size * window * hidden_size * element_size.
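  // The first qkv_size bytes hold the Q/K/V projections; the second qkv_size bytes are used
  // for the global Q/K/V projections (see global_gemm_buffer below).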
auto gemm_buffer = GetScratchBuffer<void>(qkv_size + qkv_size);
bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2);
int m = batch_size * sequence_length;
int n = use_merged_qkv_weights ? 3 * hidden_size : hidden_size;
int k = hidden_size;
typedef typename ToHipType<T>::MappedType HipT;
const HipT* input_data = reinterpret_cast<const HipT*>(input->Data<T>());
const HipT* weights_data = reinterpret_cast<const HipT*>(weights->Data<T>());
const HipT* global_weights_data = reinterpret_cast<const HipT*>(global_weights->Data<T>());
float one = 1.0f;
float zero = 0.0f;
if (use_merged_qkv_weights) {
// Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 0 x B.
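    // With row-major data, X [m, k] and W [k, n] viewed as col-major become X^T (k, m) and
    // W^T (n, k); calling gemm(n, m, k) with A = W (lda = n) and B = X (ldb = k) therefore
    // produces C^T (n, m) in col-major order, which is exactly the row-major result [m, n].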
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
weights_data, n,
input_data, k,
&zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n, device_prop));
} else {
// q
const HipT* q_weight = weights_data;
HipT* q_data = reinterpret_cast<HipT*>(gemm_buffer.get());
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
q_weight, n,
input_data, k,
&zero, q_data, n, device_prop));
// k
const HipT* k_weight = q_weight + hidden_size * hidden_size;
HipT* k_data = q_data + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
k_weight, n,
input_data, k,
&zero, k_data, n, device_prop));
// v
const HipT* v_weight = k_weight + hidden_size * hidden_size;
HipT* v_data = k_data + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
v_weight, n,
input_data, k,
&zero, v_data, n, device_prop));
}
// Wait for async copy of batch_global_num
HIP_RETURN_IF_ERROR(hipEventSynchronize(is_copy_done));
// Find the maximum number of global tokens in all batches
int max_num_global = 0;
for (int i = 0; i < batch_size; ++i) {
if (max_num_global < batch_global_num_pinned[i]) {
max_num_global = batch_global_num_pinned[i];
}
}
// Do not use compact memory kernel in the following situations:
  // (1) global tokens > window size: the compact memory kernel cannot be used due to its assumptions.
  // (2) sequence_length == 2 * attention_window: the compact memory kernel has a parity issue.
  // (3) the user sets the environment variable ORT_LONGFORMER_COMPACT_MEMORY=0.
bool disable_compact_memory = (max_num_global > window_ || sequence_length == 2 * window_ || !use_compact_memory_);
  // Fully connected layer for global projection.
  // Note that Q only needs to handle global query tokens if we split the GEMM into global Q/K/V separately.
  // When there is no global token, there is no need to run the global GEMM.
HipT* global_gemm_buffer = nullptr;
if (max_num_global > 0) {
global_gemm_buffer = reinterpret_cast<HipT*>(reinterpret_cast<char*>(gemm_buffer.get()) + qkv_size);
if (use_merged_qkv_weights) {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(global_weights->Data<T>()), n,
input_data, k,
&zero, global_gemm_buffer, n, device_prop));
} else {
// global q
const HipT* global_q_weight = global_weights_data;
HipT* global_q = global_gemm_buffer + 2 * batch_size * sequence_length * hidden_size;
if (disable_compact_memory) {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_q_weight, n,
input_data, k,
&zero, global_q, n, device_prop));
} else {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(rocblas,
rocblas_operation_none,
rocblas_operation_none,
hidden_size, // m
max_num_global, // n
hidden_size, // k
&one, // alpha
global_q_weight, // A
hidden_size, // lda
0, // strideA
input_data, // B
hidden_size, // ldb
sequence_length * hidden_size, // strideB
&zero, // beta
global_q, // C
hidden_size, // ldc
max_num_global * hidden_size, // strideC
batch_size, // batch count
device_prop));
}
// global k
const HipT* global_k_weight = global_weights_data + hidden_size * hidden_size;
HipT* global_k = global_gemm_buffer;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_k_weight, n,
input_data, k,
&zero, global_k, n, device_prop));
// global v
const HipT* global_v_weight = global_k_weight + hidden_size * hidden_size;
HipT* global_v = global_gemm_buffer + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_v_weight, n,
input_data, k,
&zero, global_v, n, device_prop));
}
}
size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size,
batch_size,
num_heads_,
head_size,
sequence_length,
max_num_global,
window_,
disable_compact_memory);
auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(
device_prop,
rocblas,
stream,
reinterpret_cast<const HipT*>(gemm_buffer.get()),
reinterpret_cast<const HipT*>(bias->Data<T>()),
reinterpret_cast<const HipT*>(attention_mask->Data<T>()),
reinterpret_cast<const HipT*>(global_gemm_buffer),
reinterpret_cast<const HipT*>(global_bias->Data<T>()),
global_attention_mask->Data<int>(),
global_index_buffer.get(),
batch_global_num_buffer.get(),
pinned_buffer.get(),
workspace_buffer.get(),
output->MutableData<T>(),
batch_size,
sequence_length,
num_heads_,
head_size,
window_,
max_num_global,
element_size,
disable_compact_memory,
use_merged_qkv_weights,
use_half4_));
  // Defer release of the pinned memory since hipStreamSynchronize is not used here and the kernel needs access to the buffer.
this->AddDeferredReleaseCPUPtr(pinned_buffer.release());
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class LongformerAttention final : public RocmKernel, public LongformerAttentionBase {
public:
LongformerAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool use_compact_memory_;
bool use_half4_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
size_t GetPinnedBufferSize(
size_t batch_size);
size_t GetLongformerAttentionWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t head_size,
size_t sequence_length,
size_t max_num_global,
size_t window,
bool disable_compact_memory);
Status LaunchLongformerAttentionKernel(
const hipDeviceProp_t& device_prop, // Device Properties
rocblas_handle rocblas, // Rocblas handle
hipStream_t stream, // ROCM stream
const void* input, // Input tensor
const void* bias, // Bias tensor
const void* attention_mask, // Attention mask with shape (B, S)
const void* global_input, // Global attention input, or nullptr when max_num_global == 0.
const void* global_bias, // Global bias tensor
const int* global_attention, // Global attention flags with shape (B, S)
const int* global_index, // Global index
const int* batch_global_num, // Number of global tokens per batch. It is in device memory.
void* pinned_buffer, // Pinned memory: copy of batch_global_num, and a buffer to copy to scratch2.
void* workspace, // Temporary buffer
void* output, // Output tensor
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
int num_heads, // Number of attention heads (N)
int head_size, // Hidden layer size per head (H)
int window, // One sided attention window (W)
int max_num_global, // Maximum number of global tokens (G)
const size_t element_size, // Element size of input tensor,
bool disable_compact_memory, // Disable compact memory kernel
bool use_merged_qkv_weights,
bool use_half4);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime