"...lm-evaluation-harness.git" did not exist on "f1e62d3650542c13b3e1fd9fbee4a4c0757c4f4d"
Commit 1a91fcc2 authored by gaoqiong

add files required by dtk

parent a144865d
Pipeline #492 failed
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "activations.h"
#include "core/framework/op_kernel.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
domain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.MayInplace(0, 0), \
x<T>);
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \
UNARY_ACTIVATION_COMPUTE(name, T)
#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
} //namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/unary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/activation/activations.h"
#include "activations_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
class Affine final : public UnaryElementwise {
public:
Affine(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class ParametricSoftplus final : public UnaryElementwise {
public:
ParametricSoftplus(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class ScaledTanh final : public UnaryElementwise {
public:
ScaledTanh(const OpKernelInfo& info) : UnaryElementwise(info) {
ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA_BETA()
float alpha_;
float beta_;
};
template <typename T>
class Gelu final : public UnaryElementwise {
public:
Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class QuickGelu final : public UnaryElementwise {
public:
QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct OP_Affine : public CtxAffine {
__device__ __inline__ T operator()(const T& a) const {
return a * (T)alpha + (T)beta;
}
};
template <typename T>
struct OP_ParametricSoftplus : public CtxParametricSoftplus {
__device__ __inline__ T operator()(const T& a) const {
if (a > (T)0)
return (T)alpha * (a * (T)beta + _Log(_Exp(-a * (T)beta) + (T)1));
else
return (T)alpha * _Log(_Exp(a * (T)beta) + (T)1);
}
};
template <typename T>
struct OP_ScaledTanh : public CtxScaledTanh {
__device__ __inline__ T operator()(const T& a) const {
return (T)alpha * _Tanh(a * (T)beta);
}
};
template <typename T>
struct OP_Gelu : public CtxGelu {
__device__ __inline__ T operator()(const T& a) const {
return _Gelu(a);
}
};
template <>
struct OP_Gelu<half> : public CtxGelu {
__device__ __inline__ half operator()(const half& a) const {
return static_cast<half>(_Gelu(static_cast<float>(a)));
}
};
template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
T v = a * static_cast<T>(alpha);
T one = static_cast<T>(1.f);
T zero = static_cast<T>(0.f);
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
return a * sigmoid;
}
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double)
#define UNARY_ACTIVATION_OP_NAME(name) \
UNARY_ACTIVATION_IMPL(name); \
SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
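For reference, the math implemented by the functors above can be checked on the host. A minimal sketch (plain C++, float only, not part of the original file) mirroring OP_Affine, OP_ParametricSoftplus, OP_ScaledTanh and OP_QuickGelu; the Gelu line assumes the exact erf-based formulation:
#include <cmath>
#include <cstdio>
// Host-side reference of the device functors above (illustrative only).
float affine_ref(float a, float alpha, float beta) { return a * alpha + beta; }
// alpha * log(exp(beta * a) + 1), written in the overflow-safe form used by OP_ParametricSoftplus.
float parametric_softplus_ref(float a, float alpha, float beta) {
  return (a > 0.0f) ? alpha * (a * beta + std::log(std::exp(-a * beta) + 1.0f))
                    : alpha * std::log(std::exp(a * beta) + 1.0f);
}
float scaled_tanh_ref(float a, float alpha, float beta) { return alpha * std::tanh(a * beta); }
// Assumed erf-based Gelu: 0.5 * x * (1 + erf(x / sqrt(2))).
float gelu_ref(float a) { return 0.5f * a * (1.0f + std::erf(a * 0.70710678f)); }
// QuickGelu: x * sigmoid(alpha * x); exp() never sees a large positive argument.
float quick_gelu_ref(float a, float alpha = 1.702f) {
  const float v = a * alpha;
  const float sigmoid = (v >= 0.0f) ? 1.0f / (1.0f + std::exp(-v))
                                    : 1.0f - 1.0f / (1.0f + std::exp(v));
  return a * sigmoid;
}
int main() {
  std::printf("%f %f\n", gelu_ref(1.0f), quick_gelu_ref(1.0f));  // ~0.8413, ~0.8458
  return 0;
}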
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/activation/activations_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
typedef onnxruntime::rocm::CtxAlphaBeta CtxAffine;
typedef onnxruntime::rocm::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::rocm::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::rocm::CtxNull CtxGelu;
typedef onnxruntime::rocm::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
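The pattern above is an X-macro: UNARY_CONTRIB_ACTIVATION_OPS() lists the contrib activations once, and each consumer defines UNARY_ACTIVATION_OP_NAME before expanding it (declarations here, implementations and explicit instantiations in the .cu file). A stripped-down, self-contained sketch of the same technique (demo names only, unrelated to the real macros):
#include <cstdio>
// The list of names is written once ...
#define DEMO_OPS() \
  DEMO_OP(ScaledTanh) \
  DEMO_OP(Affine) \
  DEMO_OP(QuickGelu)
// ... and expanded once per use-site with a different DEMO_OP definition.
#define DEMO_OP(name) void Print_##name();
DEMO_OPS()  // generates one declaration per op
#undef DEMO_OP
#define DEMO_OP(name) void Print_##name() { std::printf("%s\n", #name); }
DEMO_OPS()  // generates one definition per op
#undef DEMO_OP
int main() {
  Print_ScaledTanh();
  Print_Affine();
  Print_QuickGelu();
  return 0;
}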
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "contrib_ops/cpu/aten_ops/aten_op.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
ONNX_OPERATOR_KERNEL_EX(
ATen, kPytorchAtenDomain, 1, kRocmExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()),
onnxruntime::contrib::ATen);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
namespace onnxruntime {
namespace rocm {
struct __align__(8) Half4 {
half2 x;
half2 y;
};
__device__ __forceinline__ Half4 operator+(const Half4& a, const Half4& b) {
Half4 r;
r.x = a.x + b.x;
r.y = a.y + b.y;
return r;
}
__device__ __forceinline__ float2 operator+(const float2& a, const float2& b) {
return make_float2(a.x + b.x, a.y + b.y);
}
__device__ __forceinline__ float4 operator+(const float4& a, const float4& b) {
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
} // namespace rocm
} // namespace onnxruntime
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void AddBiasTransposeTrt(const T* input, const T* biases, T* output) {
// Input: BxSxMxNxH (Format 2)
// Output: BxSxNxMxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int H = blockDim.x;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const int NH = N * H;
const int offset = (b * S + s) * M * NH;
const int in_offset = offset + m * NH + n * H;
const int out_offset = offset + (n * M + m) * H;
const int h = threadIdx.x;
if (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
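The offset arithmetic in this kernel is easiest to sanity-check on the host. An illustrative sketch (plain C++, arbitrary sizes, not part of the original file) that performs the same BxSxMxNxH -> BxSxNxMxH permutation with explicit 5-D indexing and compares it against the flattened offsets used above:
#include <cassert>
#include <vector>
int main() {
  const int B = 2, S = 3, M = 3, N = 4, H = 5;  // hypothetical sizes
  const int NH = N * H;
  std::vector<float> in(static_cast<size_t>(B) * S * M * N * H), bias(static_cast<size_t>(M) * N * H);
  std::vector<float> out_ref(in.size()), out_kernel(in.size());
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
  for (size_t i = 0; i < bias.size(); ++i) bias[i] = 0.5f * static_cast<float>(i);
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < H; ++h) {
            // Reference: out[b][s][n][m][h] = in[b][s][m][n][h] + bias[m][n][h].
            const int in_5d = (((b * S + s) * M + m) * N + n) * H + h;
            const int out_5d = (((b * S + s) * N + n) * M + m) * H + h;
            out_ref[out_5d] = in[in_5d] + bias[(m * N + n) * H + h];
            // The same element addressed with the kernel's flattened offsets.
            const int offset = (b * S + s) * M * NH;
            const int in_offset = offset + m * NH + n * H;
            const int out_offset = offset + (n * M + m) * H;
            out_kernel[out_offset + h] = in[in_offset + h] + bias[m * NH + n * H + h];
          }
  assert(out_ref == out_kernel);  // the two indexings agree element for element
  return 0;
}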
template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int H = head_size;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const int NH = N * H;
const int offset = (b * S + s) * M * NH;
const int in_offset = offset + m * NH + n * H;
const int out_offset = offset + (n * M + m) * H;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTransposeTrt(const T* query, const T* key, const T* value, const T* biases, T* output) {
// Q: BxSxNxH
// K: BxSxNxH
// V: BxSxNxH
// Output: BxSxNxMxH
// B is batch_size, S is sequence_length, M is number of matrices (3), N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int H = blockDim.x;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const T* input = (m == 0 ? query : (m == 1 ? key : value));
const int NH = N * H;
const int in_offset = (b * S + s) * NH + n * H;
const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;
const int h = threadIdx.x;
if (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size,
const T* query, const T* key, const T* value, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int stride = blockDim.x;
const int H = head_size;
const int N = blockDim.y;
const int S = gridDim.x;
const int M = gridDim.z;
const T* input = (m == 0 ? query : (m == 1 ? key : value));
const int NH = N * H;
const int in_offset = (b * S + s) * NH + n * H;
const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output) {
// Input: BxSxMxNxH (Format 1)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int M = gridDim.z;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
const int out_offset = s * head_size + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
// Input: BxSxMxNxH (Format 1)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y; // head_num_id
int s = blockIdx.x; // sequence_id
int b = blockIdx.y; // batch_id
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
const int h = threadIdx.x; // head_element_id
const int qk_head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int head_size = (m == 2 ? v_head_size : qk_head_size);
const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);
int in_offset;
int out_offset;
int bias_offset;
in_offset = b * (total_head_size * sequence_length) + // B
s * (total_head_size) + // S
m * (qk_head_size * num_heads) + // M
n * head_size + // N
h; // H
out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
b * (num_heads * head_size * sequence_length) + // B
n * (sequence_length * head_size) + // N
s * (head_size) + // S
h; // H
bias_offset = m * (num_heads * qk_head_size) + // M
n * (head_size) + // N
h; // H
if (h < head_size) {
output[out_offset] = input[in_offset] + biases[bias_offset];
}
}
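When v_head_size differs from qk_head_size, the packed input row is num_heads * (2 * qk_head_size + v_head_size) wide and the output stores Q and K as BxNxSxH_qk followed by V as BxNxSxH_v. A small host sketch (plain C++, arbitrary sizes, not part of the original file) that replays the kernel's offset formulas and checks that every packed output element is produced exactly once:
#include <cassert>
#include <vector>
int main() {
  const int B = 2, S = 3, N = 4, QKH = 6, VH = 10;   // hypothetical sizes
  const int total_head_size = N * (QKH + QKH + VH);  // width of one packed input token
  const size_t out_size = static_cast<size_t>(2) * B * N * S * QKH + static_cast<size_t>(B) * N * S * VH;
  std::vector<int> writes(out_size, 0);
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < 3; ++m) {
        const int head_size = (m == 2) ? VH : QKH;
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < head_size; ++h) {
            const int in_offset = b * (total_head_size * S) + s * total_head_size +
                                  m * (QKH * N) + n * head_size + h;
            const int out_offset = m * (N * QKH * S * B) + b * (N * head_size * S) +
                                   n * (S * head_size) + s * head_size + h;
            assert(in_offset >= 0 && in_offset < B * S * total_head_size);
            ++writes[out_offset];
          }
      }
  for (int w : writes) assert(w == 1);  // the mapping is a bijection onto the packed QKV output
  return 0;
}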
template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int M = gridDim.z;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
// Input: MxBxSxNxH (Format 0)
// Output: MxBxNxSxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
const int out_offset = s * H + n * sequence_length * H + (b + m * batch_size) * NHS;
const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
const int out_offset = (s + n * sequence_length) * H + (b + m * batch_size) * NHS;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
void InvokeAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, const int v_head_size) {
const dim3 grid(sequence_length, batch_size, num_matrices);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
if (format == 2) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, input, biases, output);
} else if (format == 1) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output, v_head_size);
}
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTranspose<T>), grid, block, 0, stream, input, biases, output);
}
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
if (format == 2) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
} else if (format == 1) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKVLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
} else {
ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block");
}
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
}
}
}
template <>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output,
bool enable_half4, const int v_head_size) {
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
const int H = qk_head_size / 4;
const int H_v = v_head_size / 4;
const Half4* input2 = reinterpret_cast<const Half4*>(input);
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
Half4* output2 = reinterpret_cast<Half4*>(output);
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
const int H = qk_head_size / 2;
const int H_v = v_head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
const half2* biases2 = reinterpret_cast<const half2*>(biases);
half2* output2 = reinterpret_cast<half2*>(output);
InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
} else {
InvokeAddBiasTranspose<half>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}
template <>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output,
bool /*enable_half4*/, const int v_head_size) {
if (0 == (qk_head_size % 4)) {
const int H = qk_head_size / 4;
const float4* input2 = reinterpret_cast<const float4*>(input);
const float4* biases2 = reinterpret_cast<const float4*>(biases);
float4* output2 = reinterpret_cast<float4*>(output);
InvokeAddBiasTranspose<float4>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
} else if (0 == (qk_head_size & 1)) {
const int H = qk_head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
const float2* biases2 = reinterpret_cast<const float2*>(biases);
float2* output2 = reinterpret_cast<float2*>(output);
InvokeAddBiasTranspose<float2>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
} else {
InvokeAddBiasTranspose<float>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
}
}
template <typename T>
void InvokeAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* biases, const T* query, const T* key, const T* value, T* output) {
constexpr int num_matrices = 3;
const dim3 grid(sequence_length, batch_size, num_matrices);
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, query, key, value, biases, output);
} else {
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, head_size, query, key, value, biases, output);
}
}
template <>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const float* biases, const float* query, const float* key, const float* value, float* output) {
ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
}
template <>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const half* biases, const half* query, const half* key, const half* value, half* output) {
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const Half4* query2 = reinterpret_cast<const Half4*>(query);
const Half4* key2 = reinterpret_cast<const Half4*>(key);
const Half4* value2 = reinterpret_cast<const Half4*>(value);
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
Half4* output2 = reinterpret_cast<Half4*>(output);
InvokeAddBiasTransposeTrt<Half4>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, H,
biases2, query2, key2, value2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* query2 = reinterpret_cast<const half2*>(query);
const half2* key2 = reinterpret_cast<const half2*>(key);
const half2* value2 = reinterpret_cast<const half2*>(value);
const half2* biases2 = reinterpret_cast<const half2*>(biases);
half2* output2 = reinterpret_cast<half2*>(output);
InvokeAddBiasTransposeTrt<half2>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, H,
biases2, query2, key2, value2, output2);
} else {
InvokeAddBiasTransposeTrt<half>(stream, max_threads_per_block,
batch_size, sequence_length, num_heads, head_size,
biases, query, key, value, output);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Fused kernel of Add (bias) and Transpose.
// Shape of inputs and outputs:
// biases: (num_matrices, num_heads * head_size)
// format 0:
// input: (num_matrices, batch_size, sequence_length, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 1:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 2:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
template <typename T>
void LaunchAddBiasTranspose(
hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);
// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// It assumes sequence_length == kv_sequence_length and head_size == v_head_size.
template <typename T>
void LaunchAddBiasTransposeTrt(
hipStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const T* biases, const T* query, const T* key, const T* value, T* output);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
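A hedged usage sketch for this header (HIP host code, not part of the original commit): packed QKV in format 1 with equal Q/K/V head sizes, following the shape comments above; all sizes are illustrative and error handling is omitted.
#include <hip/hip_runtime.h>
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
void AddBiasTransposeExample(hipStream_t stream, int max_threads_per_block) {
  const int num_matrices = 3, format = 1;    // BxSxMxNxH -> MxBxNxSxH
  const int B = 2, S = 128, N = 12, H = 64;  // hypothetical shape
  const size_t elems = static_cast<size_t>(num_matrices) * B * S * N * H;
  float *d_input = nullptr, *d_biases = nullptr, *d_output = nullptr;
  hipMalloc(reinterpret_cast<void**>(&d_input), elems * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&d_biases), static_cast<size_t>(num_matrices) * N * H * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&d_output), elems * sizeof(float));
  // enable_half4 is ignored by the float specialization; v_head_size equals qk_head_size here.
  onnxruntime::contrib::rocm::LaunchAddBiasTranspose<float>(
      stream, num_matrices, format, max_threads_per_block,
      B, S, N, H, d_input, d_biases, d_output, /*enable_half4*/ false, /*v_head_size*/ H);
  hipStreamSynchronize(stream);
  hipFree(d_input);
  hipFree(d_biases);
  hipFree(d_output);
}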
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
const T* tensor_in,
const T* tensor_add,
T* tensor_out) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int chunk_id = blockIdx.z;
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int H = blockDim.x;
// K: number of identical tensors
// tensor_in: K x BxNxPxH
// tensor_add: K x BxNxLxH
// tensor_out: K x BxNxTxH, where T = P + L
const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
if (s < tensor_in_sequence_length) {
const int past_SH = tensor_in_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
tensor_out[out_offset] = tensor_in[in_offset];
} else if (s < all_sequence_length) {
const int SH = tensor_add_sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
tensor_out[out_offset] = tensor_add[in_offset];
}
}
template <typename T>
__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length,
const int H,
const T* tensor_in,
const T* tensor_add,
T* tensor_out) {
// Used when H * num_heads > max_threads_per_block.
int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int chunk_id = blockIdx.z;
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int stride = blockDim.x;
// K: number of identical tensors
// tensor_in: K x BxNxPxH
// tensor_add: K x BxNxLxH
// tensor_out: K x BxNxTxH
const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
while (h < H) {
int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
if (s < tensor_in_sequence_length) {
const int past_SH = tensor_in_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
tensor_out[out_offset] = tensor_in[in_offset];
} else if (s < all_sequence_length) {
const int SH = tensor_add_sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
tensor_out[out_offset] = tensor_add[in_offset];
}
h += stride;
}
}
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float>), grid, block, 0, stream, sequence_length, tensor_in, tensor_add, tensor_out);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float>), grid, block, 0, stream, sequence_length,
head_size,
tensor_in,
tensor_add,
tensor_out);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out) {
const dim3 grid(all_sequence_length, batch_size, matrix_num);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const float2*>(tensor_in),
reinterpret_cast<const float2*>(tensor_add),
reinterpret_cast<float2*>(tensor_out));
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half2>), grid, block, 0, stream, sequence_length,
reinterpret_cast<const half2*>(tensor_in),
reinterpret_cast<const half2*>(tensor_add),
reinterpret_cast<half2*>(tensor_out));
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half2>), grid, block, 0, stream, sequence_length,
H,
reinterpret_cast<const half2*>(tensor_in),
reinterpret_cast<const half2*>(tensor_add),
reinterpret_cast<half2*>(tensor_out));
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half>), grid, block, 0, stream, sequence_length, tensor_in, tensor_add, tensor_out);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half>), grid, block, 0, stream, sequence_length,
head_size,
tensor_in,
tensor_add,
tensor_out);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
sequence_length,
batch_size,
head_size,
num_heads,
max_threads_per_block,
2,
past,
k_v,
present);
}
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present) {
return LaunchConcatTensorToTensor(
stream,
all_sequence_length,
sequence_length,
batch_size,
head_size,
num_heads,
max_threads_per_block,
2,
past,
k_v,
present);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
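The concatenation semantics above (append tensor_add after tensor_in along the sequence axis, independently for each of the matrix_num chunks) can be written as a short host reference. An illustrative sketch in plain C++ mirroring the offsets of ConcatTensorToTensor; LaunchConcatPastToPresent is the special case with two chunks (K and V). Not part of the original file.
#include <cstddef>
#include <vector>
// Host reference: K x BxNxPxH  +  K x BxNxLxH  ->  K x BxNxTxH with T = P + L (illustrative only).
std::vector<float> ConcatPastToPresentRef(const std::vector<float>& past,   // K*B*N*P*H
                                          const std::vector<float>& added,  // K*B*N*L*H
                                          int K, int B, int N, int P, int L, int H) {
  const int T = P + L;
  std::vector<float> present(static_cast<size_t>(K) * B * N * T * H);
  for (int k = 0; k < K; ++k)
    for (int b = 0; b < B; ++b)
      for (int n = 0; n < N; ++n)
        for (int t = 0; t < T; ++t)
          for (int h = 0; h < H; ++h) {
            const size_t out = ((((static_cast<size_t>(k) * B + b) * N + n) * T + t) * H) + h;
            if (t < P)  // copy the existing ("past") rows
              present[out] = past[((((static_cast<size_t>(k) * B + b) * N + n) * P + t) * H) + h];
            else        // append the newly computed rows
              present[out] = added[((((static_cast<size_t>(k) * B + b) * N + n) * L + (t - P)) * H) + h];
          }
  return present;
}
int main() {
  // K=2 chunks (key and value), B=N=1, P=2 past rows, L=1 new row, H=2.
  const std::vector<float> past{1, 2, 3, 4, 1, 2, 3, 4};
  const std::vector<float> added{5, 6, 5, 6};
  const auto present = ConcatPastToPresentRef(past, added, 2, 1, 1, 2, 1, 2);
  return present.size() == 12 ? 0 : 1;  // 2 * 1 * 1 * (2 + 1) * 2
}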
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
size_t GetAttentionScratchSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t sequence_length,
size_t all_sequence_length);
size_t GetAttentionWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t qk_head_size,
size_t v_head_size,
size_t sequence_length,
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner);
template <typename T>
struct AttentionData {
const T* gemm_buffer;
const T* bias;
const T* query;
const T* key;
const T* value;
const int* mask_index;
gsl::span<const int64_t> mask_index_dims;
const T* past;
const T* extra_add_qk;
T* workspace;
T* output;
T* present;
};
template <typename T>
Status QkvToContext(
const hipDeviceProp_t& prop,
rocblas_handle& rocblas,
hipStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
void* fused_runner);
Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
hipStream_t stream, // HIP stream
rocblas_handle& rocblas, // Rocblas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
const int kv_sequence_length, // Key/Value/Cache sequence length
const int num_heads, // Number of attention heads (N)
const int head_size, // Hidden size per head (H)
const bool static_kv, // Whether cross attention or not
const bool use_past, // Whether use cache or not
const bool has_layer_state, // Whether output cache or not
const bool has_key_padding_mask, // Whether use key_padding_mask or not
const void* gemm_query_buffer, // Query buffer
const void* gemm_kv_buffer, // Key and value buffer
const bool* key_padding_mask, // Key padding mask
const void* key_cache, // Input key cache
const void* value_cache, // Input value cache
void* qkv_buffer, // Temporary buffer
void* workspace_buffer, // Temporary buffer
void* output, // Output tensor
void* new_key_cache, // New_key_cache tensor
void* new_value_cache // New_value_cache tensor
);
// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);
// BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const float* tensor_in,
const float* tensor_add,
float* tensor_out);
Status LaunchConcatTensorToTensor(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const int matrix_num,
const half* tensor_in,
const half* tensor_add,
half* tensor_out);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const float* past,
const float* k_v,
float* present);
Status LaunchConcatPastToPresent(hipStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const int max_threads_per_block,
const half* past,
const half* k_v,
half* present);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
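A hedged usage sketch for LaunchTransQkv from this header (HIP host code, assumes an ONNX Runtime ROCm build, not part of the original commit): packed BxSxMxNxH input transposed to MxBxNxSxH per the comment above; sizes are illustrative and error handling is minimal.
#include <hip/hip_runtime.h>
#include "contrib_ops/rocm/bert/attention_impl.h"
void TransQkvExample(hipStream_t stream, int max_threads_per_block) {
  const int M = 3, B = 2, S = 128, N = 12, H = 64;  // hypothetical sizes
  const size_t elems = static_cast<size_t>(M) * B * S * N * H;
  float *d_in = nullptr, *d_out = nullptr;
  hipMalloc(reinterpret_cast<void**>(&d_in), elems * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&d_out), elems * sizeof(float));
  // BxSxMxNxH -> MxBxNxSxH (reversed_bs = false).
  auto status = onnxruntime::contrib::rocm::LaunchTransQkv(
      stream, /*matrix_num*/ M, S, B, H, N, max_threads_per_block, /*reversed_bs*/ false, d_in, d_out);
  (void)status;  // a real caller would propagate this Status
  hipStreamSynchronize(stream);
  hipFree(d_in);
  hipFree(d_out);
}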
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Modifications: add transpose kernels for TRT format
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__global__ void TransposeCtx(const int H, const bool reversed_bs, const T* input, T* output) {
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
int out_offset = 0;
if (reversed_bs) {
const int batch_size = gridDim.y;
const int BNH = NH * batch_size;
out_offset = n * H + b * NH + s * BNH;
} else {
out_offset = n * H + s * NH + b * NHS;
}
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
template <typename T>
__global__ void TransposeCtxLarge(const int H, const bool reversed_bs, const T* input, T* output) {
// Used when H * num_heads > max_threads_per_block.
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int stride = blockDim.x;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
int out_offset = 0;
if (reversed_bs) {
const int batch_size = gridDim.y;
const int BNH = NH * batch_size;
out_offset = n * H + b * NH + s * BNH;
} else {
out_offset = n * H + s * NH + b * NHS;
}
int i = threadIdx.x;
while (i < H) {
output[out_offset + i] = input[in_offset + i];
i += stride;
}
}
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchTransCtx(hipStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtx<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeCtxLarge<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
template <typename T>
__global__ void TransposeQKV(const int H, const bool reversed_bs, const T* input, T* output) {
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrices
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int chunk_num = gridDim.z;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = 0;
if (reversed_bs) {
const int BNH = NH * batch_size;
in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
} else {
in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
}
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
template <typename T>
__global__ void TransposeQKVLarge(const int H, const bool reversed_bs, const T* input, T* output) {
// Used when H * num_heads > max_threads_per_block.
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrices
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int chunk_num = gridDim.z;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = 0;
if (reversed_bs) {
const int BNH = NH * batch_size;
in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
} else {
in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
}
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
int i = threadIdx.x;
while (i < H) {
output[out_offset + i] = input[in_offset + i];
i += stride;
}
}
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, matrix_num);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else {
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, matrix_num);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half2>), grid, block, 0, stream, H, reversed_bs, input2, output2);
}
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half>), grid, block, 0, stream, head_size, reversed_bs, input, output);
}
}
return HIP_CALL(hipGetLastError());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TrtSequenceOffset kernels are modified from FasterTransformer
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
constexpr int32_t kMAX_THREADS_PER_BLOCK = 256;
// -----------------------------------
// Get indices of non-padding tokens and padding tokens. Here we assume that padding is on the right side of sequence.
// sequence_token_count is number of non-padding tokens per sequence, and it has shape [batch_size].
// For example, we have 3 sequences with 1, 2, 4 non-padding tokens and positions like the following (* means padding):
// Sequence_0: 0, 1*, 2*, 3*
// Sequence_1: 4, 5, 6*, 7*
// Sequence_2: 8, 9, 10, 11
// token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7*
// token_count_buffer has two numbers for non-padding tokens:
// total_token_count: 1 + 2 + 4 = 7
// max_token_count: 4
// cumulated_token_count: 0, 1, 1+2, 1+2+4
__global__ void getTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length) {
// Find offset of non-padding tokens, and max sequence length among all batches
// TODO(tianleiwu): Use hipcub::DevicePartition::Flagged like BuildGlobalIndex in longformer_global_impl.cu
// to build token_offset when sequence length is large.
int total_tokens = 0;
int max_tokens = 0;
int index = 0;
cumulated_token_count[0] = 0;
for (int i = 0; i < batch_size; i++) {
const int count = sequence_token_count[i];
if (count > max_tokens) {
max_tokens = count;
}
cumulated_token_count[i + 1] = cumulated_token_count[i] + count;
for (int j = 0; j < count; j++) {
token_offset[index] = i * sequence_length + j;
index++;
}
total_tokens += count;
}
// Offset of paddings
for (int i = 0; i < batch_size; i++) {
const int count = sequence_token_count[i];
for (int j = 0; j < sequence_length - count; j++) {
token_offset[index] = i * sequence_length + count + j;
index++;
}
}
token_count_buffer[0] = total_tokens;
token_count_buffer[1] = max_tokens;
}
void LaunchGetTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream) {
hipLaunchKernelGGL(getTokenOffset, 1, 1, 0, stream,
token_count_buffer, token_offset, cumulated_token_count, sequence_token_count, batch_size, sequence_length);
}
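The worked example in the comment above (three sequences with 1, 2 and 4 non-padding tokens) can be reproduced on the host. An illustrative sketch in plain C++ that mirrors the single-thread kernel and checks the documented outputs; it is not part of the original file:
#include <cassert>
#include <vector>
int main() {
  const int batch_size = 3, sequence_length = 4;
  const std::vector<int> sequence_token_count = {1, 2, 4};
  std::vector<int> token_offset(batch_size * sequence_length);
  std::vector<int> cumulated(batch_size + 1, 0);
  int total = 0, max_tokens = 0, index = 0;
  // Non-padding token positions first, in batch order.
  for (int i = 0; i < batch_size; ++i) {
    const int count = sequence_token_count[i];
    if (count > max_tokens) max_tokens = count;
    cumulated[i + 1] = cumulated[i] + count;
    for (int j = 0; j < count; ++j) token_offset[index++] = i * sequence_length + j;
    total += count;
  }
  // Then the padding positions.
  for (int i = 0; i < batch_size; ++i) {
    const int count = sequence_token_count[i];
    for (int j = 0; j < sequence_length - count; ++j) token_offset[index++] = i * sequence_length + count + j;
  }
  assert(total == 7 && max_tokens == 4);
  assert((token_offset == std::vector<int>{0, 4, 5, 8, 9, 10, 11, 1, 2, 3, 6, 7}));
  assert((cumulated == std::vector<int>{0, 1, 3, 7}));
  return 0;
}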
// -----------------------------------
// Remove paddings
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
removePadding(T* target, const T* source, const int* token_offset, const int width) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int source_offset = token_offset[token_index];
const int target_offset = token_index;
for (int i = tid; i < width; i += blockDim.x) {
target[target_offset * width + i] = source[source_offset * width + i];
}
}
template <>
void LaunchRemovePadding(
half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream) {
// input: [batch_size, sequence_length, hidden_size]
// output: [token_count, hidden_size]
// Make sure memory is aligned to 128 bit
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
if (hidden_size % 8 == 0) {
const int width = hidden_size / 8;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
} else {
const int width = hidden_size;
const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
int16_t* output2 = reinterpret_cast<int16_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int16_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width);
}
}
template <>
void LaunchRemovePadding(
float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream) {
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
} else {
const int width = hidden_size;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream, output2, input2, token_offset, width);
}
}
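Both LaunchRemovePadding specializations pick the widest integer type that evenly divides one row, which is why they first enforce 16-byte alignment. A small host sketch (illustrative only, not part of the original file) of the same width-selection logic, expressed in bytes so it covers both the half and float cases:
#include <cassert>
#include <cstdint>
#include <cstdio>
// Bytes moved per copy in the vectorized row copy, mirroring the dispatch above.
int CopyChunkBytes(const void* src, const void* dst, int row_bytes) {
  // Mirrors the ORT_ENFORCE above: rows must start on a 16-byte (128-bit) boundary.
  assert((reinterpret_cast<std::uintptr_t>(src) & 0xF) == 0 &&
         (reinterpret_cast<std::uintptr_t>(dst) & 0xF) == 0);
  if (row_bytes % 16 == 0) return 16;  // int4   (8 halves or 4 floats per copy)
  if (row_bytes % 8 == 0) return 8;    // int64_t
  if (row_bytes % 4 == 0) return 4;    // int32_t
  return 2;                            // int16_t (single half); float rows never reach this case
}
int main() {
  alignas(16) static float buffer[1024];
  // hidden_size = 768 floats -> 3072 bytes per row -> copied 16 bytes (int4) at a time.
  std::printf("%d\n", CopyChunkBytes(buffer, buffer, 768 * static_cast<int>(sizeof(float))));
  return 0;
}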
// -----------------------------------
// Recover padding.
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
restorePadding(T* target, const T* source, const int* token_offset, const int width, const int token_count) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int target_seq_id = token_offset[token_index];
const int source_seq_id = token_index;
constexpr T padding_zero = 0;
if (token_index < token_count) {
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = source[source_seq_id * width + i];
}
} else {
// It is padding: fill with zeros
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = padding_zero;
}
}
}
template <>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
restorePadding(int4* target, const int4* source, const int* token_offset, const int width, const int token_count) {
const int tid = threadIdx.x;
const int token_index = blockIdx.x;
const int target_seq_id = token_offset[token_index];
const int source_seq_id = token_index;
int4 padding_zero{0, 0, 0, 0};
if (token_index < token_count) {
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = source[source_seq_id * width + i];
}
} else {
// It is padding: fill with zeros
for (int i = tid; i < width; i += blockDim.x) {
target[target_seq_id * width + i] = padding_zero;
}
}
}
template <>
void LaunchRestorePadding(
float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream) {
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
int grid_size = batch_size * sequence_length;
if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else {
const int width = hidden_size;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
}
}
template <>
void LaunchRestorePadding(
half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream) {
// input: [token_count, hidden_size]
// output: [batch_size, sequence_length, hidden_size]
ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");
int grid_size = batch_size * sequence_length;
if (hidden_size % 8 == 0) {
const int width = hidden_size / 8;
const int4* input2 = reinterpret_cast<const int4*>(input);
int4* output2 = reinterpret_cast<int4*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 4 == 0) {
const int width = hidden_size / 4;
const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
int64_t* output2 = reinterpret_cast<int64_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else if (hidden_size % 2 == 0) {
const int width = hidden_size / 2;
const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
int32_t* output2 = reinterpret_cast<int32_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
} else {
const int width = hidden_size;
const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
int16_t* output2 = reinterpret_cast<int16_t*>(output);
hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int16_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
output2, input2, token_offset, width, token_count);
}
}
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size) {
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
for (int i = 0; i < batch_size; i++) {
tmp_offset[i + 1] = tmp_offset[i] + sequence_token_count[i];
}
}
__syncthreads();
for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
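// Worked example (illustrative): for batch_size = 3 and sequence_token_count = {3, 5, 2}, the kernel
// writes trt_mha_padding_offset = {0, 3, 8, 10}, i.e. an exclusive prefix sum with batch_size + 1 entries.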
// Get sequence offset for TensorRT fused attention when there is no padding (or padding is removed)
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
hipStream_t stream) {
hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (batch_size + 1), stream,
trt_mha_padding_offset, sequence_token_count, batch_size);
}
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
const int sequence_length) {
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
// B for fused attention is 2 * batch_size
for (int i = 0; i < batch_size; i++) {
tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_token_count[i];
tmp_offset[i * 2 + 2] = sequence_length * (i + 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < 2 * batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
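// Worked example (illustrative): for batch_size = 2, sequence_length = 8 and
// sequence_token_count = {3, 5}, the kernel writes trt_mha_padding_offset = {0, 3, 8, 13, 16}:
// the entries interleave the cumulative end of each batch's real tokens with the batch boundaries
// sequence_length * (i + 1), giving 2 * batch_size + 1 entries in total.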
// Get sequence offset for TensorRT fused attention when we keep the padding
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream) {
hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (2 * batch_size + 1), stream,
trt_mha_padding_offset, sequence_token_count, batch_size, sequence_length);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Build token indices for non-padding tokens and padding tokens.
void LaunchGetTokenOffset(int* token_count_buffer,
int* token_offset,
int* cumulated_token_count,
const int* sequence_token_count,
const int batch_size,
const int sequence_length,
hipStream_t stream);
// Remove padding from the input.
template <typename T>
void LaunchRemovePadding(
T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
hipStream_t stream);
// Rebuild padding to restore the output shape.
template <typename T>
void LaunchRestorePadding(
T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
const int batch_size, const int sequence_length,
hipStream_t stream);
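// A minimal usage sketch (illustrative only; the buffer names below are placeholders, not part of this header):
//   LaunchGetTokenOffset(token_count_buffer, token_offset, cumulated_token_count,
//                        sequence_token_count, batch_size, sequence_length, stream);
//   // (copy the device-side token count back to the host before the next two calls)
//   LaunchRemovePadding<half>(compact, padded_input, token_offset, token_count, hidden_size, stream);
//   // ... run attention on the packed [token_count, hidden_size] tensor ...
//   LaunchRestorePadding<half>(padded_output, compact, token_offset, token_count, hidden_size,
//                              batch_size, sequence_length, stream);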
// Padding offset for TensorRT fused attention kernel
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* mask_index,
const int batch_size,
hipStream_t stream);
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int* mask_index,
const int batch_size,
const int sequence_length,
hipStream_t stream);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/decoder_attention.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
DecoderAttention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
DecoderAttention<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
namespace {
Status CheckInputs(const TensorShape& query_shape,
const TensorShape& key_shape,
const TensorShape& q_weights_shape,
const TensorShape& kv_weights_shape,
const TensorShape& bias_shape,
const Tensor* key_padding_mask,
const Tensor* key_cache,
const Tensor* value_cache,
const bool static_kv,
const bool use_past,
const bool has_layer_state,
const bool has_key_padding_mask) {
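  // Expected shapes, as enforced by the checks below (S = sequence length, S' = kv sequence length,
  // B = batch size, h = hidden size, N = number of heads, H = head size):
  //   query:            (S, B, h)
  //   key:              (S', B, h)
  //   q_weights:        (h, h)
  //   kv_weights:       (h, 2 * h)
  //   bias:             (3 * h)
  //   key_padding_mask: (B, key_length)                    -- optional
  //   key_cache:        (B, N, cached S, H), N * H == h    -- optional
  //   value_cache:      (B, N, cached S, H), N * H == h    -- optional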
const auto& query_shape_dims = query_shape.GetDims();
if (query_shape_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
query_shape_dims.size());
}
int sequence_length = static_cast<int>(query_shape_dims[0]);
int batch_size = static_cast<int>(query_shape_dims[1]);
int hidden_size = static_cast<int>(query_shape_dims[2]);
const auto& key_shape_dims = key_shape.GetDims();
if (key_shape_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_shape_dims.size());
}
int kv_sequence_length = static_cast<int>(key_shape_dims[0]);
if (query_shape_dims[1] != key_shape_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same batch size");
}
if (query_shape_dims[2] != key_shape_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same hidden size");
}
const auto& q_weights_dims = q_weights_shape.GetDims();
if (q_weights_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'q_weights' is expected to have 2 dimensions, got ",
q_weights_dims.size());
}
const auto& kv_weights_dims = kv_weights_shape.GetDims();
if (kv_weights_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'kv_weights' is expected to have 2 dimensions, got ",
kv_weights_dims.size());
}
if (q_weights_dims[0] != hidden_size || q_weights_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "q_weights shall have shape (hidden size, hidden size)");
}
if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
}
const auto& bias_dims = bias_shape.GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
bias_dims.size());
}
if (bias_dims[0] != 3 * static_cast<int64_t>(hidden_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias shall have shape (3 * hidden size)");
}
int key_length = kv_sequence_length;
if (key_padding_mask != nullptr && has_key_padding_mask == true) {
const auto& kp_mask_dims = key_padding_mask->Shape().GetDims();
if (kp_mask_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key_padding_mask' is expected to have 2 dimension, got ",
kp_mask_dims.size());
}
if (kp_mask_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_padding_mask shall have same batch size with query");
}
if (!has_layer_state || !use_past) {
if (!static_kv) {
key_length = sequence_length;
}
} else {
if (!static_kv) {
key_length = sequence_length + kv_sequence_length;
}
}
if (kp_mask_dims[1] != key_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"key_padding_mask shall have same sequence length as generated key");
}
}
if (key_cache != nullptr && value_cache != nullptr && has_layer_state && use_past) {
const auto& key_cache_dims = key_cache->Shape().GetDims();
if (key_cache_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key_cache' is expected to have 4 dimension, got ",
key_cache_dims.size());
}
const auto& value_cache_dims = value_cache->Shape().GetDims();
if (value_cache_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ",
value_cache_dims.size());
}
if (key_cache_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have same batch size as query");
}
if (value_cache_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have same batch size as query");
}
if (key_cache_dims[1] * key_cache_dims[3] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have correct hidden size");
}
if (value_cache_dims[1] * value_cache_dims[3] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have correct hidden size");
}
}
return Status::OK();
}
} // anonymous namespace
template <typename T>
DecoderAttention<T>::DecoderAttention(const OpKernelInfo& info) : RocmKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
}
template <typename T>
Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query(context->Input<Tensor>(0));
const Tensor* key(context->Input<Tensor>(1));
const Tensor* q_weights(context->Input<Tensor>(2));
const Tensor* kv_weights(context->Input<Tensor>(3));
const Tensor* bias(context->Input<Tensor>(4));
const Tensor* key_padding_mask(context->Input<Tensor>(5));
const Tensor* key_cache(context->Input<Tensor>(6));
const Tensor* value_cache(context->Input<Tensor>(7));
const Tensor* static_kv(context->Input<Tensor>(8));
const Tensor* use_past(context->Input<Tensor>(9));
const Tensor* has_layer_state(context->Input<Tensor>(10));
const Tensor* has_key_padding_mask(context->Input<Tensor>(11));
hipStream_t stream = Stream();
  // Copy static_kv, use_past, has_layer_state and has_key_padding_mask flags to CPU
auto pinned_buffer = AllocateBufferOnCPUPinned<void>(4 * sizeof(bool));
bool* kernel_state_pinned = reinterpret_cast<bool*>(pinned_buffer.get());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned, static_kv->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 1, use_past->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data<bool>(), sizeof(bool),
hipMemcpyDeviceToHost, stream));
// Create an event to make sure the async copy is finished before reading the data.
AutoDestoryCudaEvent new_event;
hipEvent_t& isCopyDone = new_event.Get();
HIP_RETURN_IF_ERROR(hipEventCreate(&isCopyDone));
HIP_RETURN_IF_ERROR(hipEventRecord(isCopyDone, stream));
auto& device_prop = GetDeviceProp();
// query shape (batch_size, sequence_length, input_hidden_size)
const auto& query_shape = query->Shape();
int sequence_length = static_cast<int>(query_shape[0]);
int batch_size = static_cast<int>(query_shape[1]);
int hidden_size = static_cast<int>(query_shape[2]);
const auto& key_shape = key->Shape();
int key_sequence_length = static_cast<int>(key_shape[0]);
int head_size = hidden_size / num_heads_;
  // k, v sequence length after gemm
int kv_sequence_length = 0;
// Generate q, k, v w/o cache
// query input: (S, B, h1)
// key input: (S', B, h1)
// weight: (h1, h2)
// h = N*H
rocblas_handle rocblas = RocblasHandle();
ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));
constexpr size_t element_size = sizeof(T);
typedef typename ToHipType<T>::MappedType HipT;
HipT one = ToHipType<T>::FromFloat(1.0f);
HipT zero = ToHipType<T>::FromFloat(0.0f);
int m = 0, n = 0, k = 0;
IAllocatorUniquePtr<T> gemm_query_buffer_p(nullptr);
IAllocatorUniquePtr<T> gemm_kv_buffer_p(nullptr);
HIP_RETURN_IF_ERROR(hipEventSynchronize(isCopyDone));
bool static_kv_ = *kernel_state_pinned;
bool use_past_ = *(kernel_state_pinned + 1);
bool has_layer_state_ = *(kernel_state_pinned + 2);
bool has_key_padding_mask_ = *(kernel_state_pinned + 3);
ORT_RETURN_IF_ERROR(
CheckInputs(query->Shape(),
key->Shape(),
q_weights->Shape(),
kv_weights->Shape(),
bias->Shape(),
key_padding_mask,
key_cache,
value_cache,
static_kv_,
use_past_,
has_layer_state_,
has_key_padding_mask_)
);
// calculate q
gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
n = hidden_size;
k = hidden_size;
// TODO(tianleiwu): fuse bias and transpose
// broadcast bias for query: (h2, S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>()), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
// matmul: (h2, h1)*(h1, S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(q_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
  // gemm_query_buffer in column-major layout: (h2, S*B)
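  // Note on the two GEMMs above: the first call is a rank-1 product bias x ones(1, S*B) with beta = 0,
  // which broadcasts the bias into every column; the second call then accumulates the
  // (h2, h1) x (h1, S*B) weight product on top of it with beta = 1, so no separate bias-add kernel is needed.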
  // calculate k, v
n = 2 * hidden_size;
k = hidden_size;
if (!has_layer_state_ || !use_past_) {
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
kv_sequence_length = sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in column-major layout: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size);
m = key_sequence_length * batch_size;
n = 2 * hidden_size;
k = hidden_size;
kv_sequence_length = key_sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(key->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in column-major layout: (2*h2, T_S*B)
}
} else {
ORT_ENFORCE(nullptr != key_cache && nullptr != value_cache); // (B, N, S, H)
const auto& cache_shape = key_cache->Shape();
// key and value cache have identical shape
int cache_sequence_length = static_cast<int>(cache_shape[2]);
if (!static_kv_) {
gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
m = sequence_length * batch_size;
kv_sequence_length = cache_sequence_length + sequence_length;
// broadcast bias for key and value: (2*h2, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
// matmul: (2*h2, h1)*(h1, T_S*B)
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const HipT*>(query->Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in column-major layout: (2*h2, T_S*B)
} else {
kv_sequence_length = cache_sequence_length;
}
}
size_t bytes = element_size * batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size;
auto qkv_buffer_p = GetScratchBuffer<void>(bytes);
bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (2 * head_size + kv_sequence_length);
auto workspace_p = GetScratchBuffer<void>(bytes);
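  // Scratch sizing note: qkv_buffer is sized for one Q of length sequence_length plus K and V of
  // length kv_sequence_length (hidden_size each); the workspace expression mirrors what
  // LaunchDecoderAttentionKernel needs for its temporaries (such as attention probabilities), and its
  // exact layout is owned by that kernel.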
Tensor* output(context->Output(0, query_shape));
TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size});
Tensor* new_key_cache(context->Output(1, new_cache_shape));
Tensor* new_value_cache(context->Output(2, new_cache_shape));
return LaunchDecoderAttentionKernel(
device_prop,
stream,
rocblas,
element_size,
batch_size,
sequence_length,
kv_sequence_length,
num_heads_,
head_size,
static_kv_,
use_past_,
has_layer_state_,
has_key_padding_mask_,
nullptr == gemm_query_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_query_buffer_p.get()),
nullptr == gemm_kv_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_kv_buffer_p.get()),
nullptr == key_padding_mask ? nullptr : key_padding_mask->Data<bool>(),
nullptr == key_cache ? nullptr : key_cache->Data<T>(),
nullptr == value_cache ? nullptr : value_cache->Data<T>(),
qkv_buffer_p.get(),
workspace_p.get(),
output->MutableData<T>(),
nullptr == new_key_cache ? nullptr : new_key_cache->MutableData<T>(),
nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class DecoderAttention final : public RocmKernel {
public:
DecoderAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
int num_heads_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include <hipcub/hipcub.hpp>
using namespace onnxruntime::rocm;
using namespace hipcub;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__device__ inline T Rsqrt(const T& x);
template <>
__device__ inline float Rsqrt(const float& x) {
return rsqrtf(x);
}
template <>
__device__ inline half Rsqrt(const half& x) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return hrsqrt(x);
#else
return half(rsqrtf(float(x)));
#endif
}
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hadd2(a, b);
#else
return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y));
#endif
}
struct KeyValuePairSum {
__device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a,
const hipcub::KeyValuePair<float, float>& b) {
return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
}
__device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a,
const hipcub::KeyValuePair<half, half>& b) {
const half2 a2 = __halves2half2(a.key, a.value);
const half2 b2 = __halves2half2(b.key, b.value);
const half2 res = AddHalf2(a2, b2);
return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
}
__device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a,
const hipcub::KeyValuePair<half2, half2>& b) {
return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}
};
template <typename T, int TPB>
__device__ inline void LayerNorm(
const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
const T* gamma, const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
const T b = (nullptr == beta) ? (T)0 : beta[i];
output[idx] = g * (val - mu) * rsigma + b;
}
}
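// Minimal caller sketch (illustrative, not part of this header): each thread accumulates a
// KeyValuePair of (x / ld, x * x / ld) over its strided elements, so after the block reduce
// sum_kv.key is E[x], sum_kv.value is E[x^2], and rsigma = 1 / sqrt(E[x^2] - E[x]^2 + epsilon).
//
//   hipcub::KeyValuePair<float, float> thread_data(0.f, 0.f);
//   KeyValuePairSum pair_sum;
//   for (int i = threadIdx.x; i < ld; i += TPB) {
//     const float x = static_cast<float>(input[offset + i]);
//     thread_data = pair_sum(thread_data, hipcub::KeyValuePair<float, float>(x / ld, x * x / ld));
//   }
//   // output[offset + i] must already hold the values to normalize (LayerNorm reads them back).
//   LayerNorm<float, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);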
template <typename T, int TPB, int ILP>
__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<T, T>& thread_data,
const int ld, const int idx, const T* beta, const T* gamma,
const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
  // Small settings: the block covers the leading dimension (TPB >= ld), and the input
  // values are already available in registers.
using VecT = aligned_vector<T, ILP>;
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
T beta_v[ILP], gamma_v[ILP], output_v[ILP];
if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
}
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
VecT* output_val = reinterpret_cast<VecT*>(&output_v);
KeyValuePairSum pair_sum;
const hipcub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
if (ILP * threadIdx.x < ld) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = (beta != nullptr)
? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
: gamma_v[i] * (input_v[i] - mu) * rsigma;
}
*(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
}
}
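// Note: LayerNormSmall assumes TPB * ILP >= ld and that ld is a multiple of ILP, so every vectorized
// load/store of ILP elements stays within the row; callers are expected to pick TPB and ILP accordingly
// (this is an inference from the indexing above, not a documented contract).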
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/rocm/bert/longformer_global_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
LongformerAttention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
LongformerAttention<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info)
: RocmKernel(info), LongformerAttentionBase(info) {
use_compact_memory_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseCompactMemory, true);
use_half4_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseHalf4, true);
}
template <typename T>
Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* attention_mask = context->Input<Tensor>(3);
const Tensor* global_weights = context->Input<Tensor>(4);
const Tensor* global_bias = context->Input<Tensor>(5);
const Tensor* global_attention_mask = context->Input<Tensor>(6);
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), attention_mask->Shape(),
global_weights->Shape(), global_bias->Shape(), global_attention_mask->Shape()));
// Input shapes:
// input : (batch_size, sequence_length, hidden_size)
// weights : (hidden_size, 3 * hidden_size) -- format 1
// (3, hidden_size, hidden_size) -- format 0
// bias : (3 * hidden_size) -- format 1 (bias for Q, K, V)
// (5 * hidden_size) -- format 0 (bias for Q, K, V, Global_K, Global_V)
// attention_mask : (batch_size, sequence_length)
// global_weights : (hidden_size, 3 * hidden_size) -- format 1
// (3, hidden_size, hidden_size) -- format 0
// global_bias : (3 * hidden_size) -- format 1 (bias for Global_Q, Global_K, Global_V)
// (1 * hidden_size) -- format 0 (bias for Global_Q)
// global_attention_mask : (batch_size, sequence_length)
// Output shapes:
// output : (batch_size, sequence_length, hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int hidden_size = static_cast<int>(shape[2]);
int head_size = hidden_size / num_heads_;
Tensor* output = context->Output(0, shape);
rocblas_handle rocblas = RocblasHandle();
hipStream_t stream = Stream();
ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));
constexpr size_t element_size = sizeof(T);
// TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node.
// Build Global Index
auto global_index_buffer = GetScratchBuffer<int>(static_cast<size_t>(batch_size) * sequence_length);
auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);
size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length);
auto global_scratch_buffer = GetScratchBuffer<void>(global_scratch_bytes);
auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(BuildGlobalIndex(
device_prop,
stream,
global_attention_mask->Data<int>(),
batch_size,
sequence_length,
global_index_buffer.get(),
batch_global_num_buffer.get(),
global_scratch_buffer.get(),
global_scratch_bytes));
// Copy batch_global_num to CPU
size_t pinned_buffer_bytes = GetPinnedBufferSize(batch_size);
auto pinned_buffer = AllocateBufferOnCPUPinned<void>(pinned_buffer_bytes);
int* batch_global_num_pinned = reinterpret_cast<int*>(pinned_buffer.get());
HIP_RETURN_IF_ERROR(hipMemcpyAsync(batch_global_num_pinned,
batch_global_num_buffer.get(),
batch_size * sizeof(int),
hipMemcpyDeviceToHost,
stream));
// Create an event to make sure the async copy is finished before reading the data.
AutoDestoryCudaEvent new_event;
hipEvent_t& is_copy_done = new_event.Get();
HIP_RETURN_IF_ERROR(hipEventCreateWithFlags(&is_copy_done, hipEventDisableTiming));
HIP_RETURN_IF_ERROR(hipEventRecord(is_copy_done, stream));
size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size;
// Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v
  // TODO(tianleiwu): compact global_q only needs a buffer of batch_size * window * hidden_size * element_size bytes.
auto gemm_buffer = GetScratchBuffer<void>(qkv_size + qkv_size);
bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2);
int m = batch_size * sequence_length;
int n = use_merged_qkv_weights ? 3 * hidden_size : hidden_size;
int k = hidden_size;
typedef typename ToHipType<T>::MappedType HipT;
const HipT* input_data = reinterpret_cast<const HipT*>(input->Data<T>());
const HipT* weights_data = reinterpret_cast<const HipT*>(weights->Data<T>());
const HipT* global_weights_data = reinterpret_cast<const HipT*>(global_weights->Data<T>());
float one = 1.0f;
float zero = 0.0f;
if (use_merged_qkv_weights) {
    // GEMM. Note that rocBLAS assumes column-major layout, so result(N, M) = 1 * weights x input + 0 x B.
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
weights_data, n,
input_data, k,
&zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n, device_prop));
} else {
// q
const HipT* q_weight = weights_data;
HipT* q_data = reinterpret_cast<HipT*>(gemm_buffer.get());
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
q_weight, n,
input_data, k,
&zero, q_data, n, device_prop));
// k
const HipT* k_weight = q_weight + hidden_size * hidden_size;
HipT* k_data = q_data + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
k_weight, n,
input_data, k,
&zero, k_data, n, device_prop));
// v
const HipT* v_weight = k_weight + hidden_size * hidden_size;
HipT* v_data = k_data + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
v_weight, n,
input_data, k,
&zero, v_data, n, device_prop));
}
// Wait for async copy of batch_global_num
HIP_RETURN_IF_ERROR(hipEventSynchronize(is_copy_done));
// Find the maximum number of global tokens in all batches
int max_num_global = 0;
for (int i = 0; i < batch_size; ++i) {
if (max_num_global < batch_global_num_pinned[i]) {
max_num_global = batch_global_num_pinned[i];
}
}
// Do not use compact memory kernel in the following situations:
  // (1) global tokens > window size: the compact memory kernel cannot be used due to its assumptions.
  // (2) sequence_length == 2 * attention_window: the compact memory kernel has a parity issue.
// (3) user sets environment variable ORT_LONGFORMER_COMPACT_MEMORY=0
bool disable_compact_memory = (max_num_global > window_ || sequence_length == 2 * window_ || !use_compact_memory_);
  // Fully connected projection for global tokens.
  // Note that Q only needs to handle global query tokens if we split the GEMM into separate global Q/K/V parts.
  // When there is no global token, the global GEMM can be skipped.
HipT* global_gemm_buffer = nullptr;
if (max_num_global > 0) {
global_gemm_buffer = reinterpret_cast<HipT*>(reinterpret_cast<char*>(gemm_buffer.get()) + qkv_size);
if (use_merged_qkv_weights) {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(global_weights->Data<T>()), n,
input_data, k,
&zero, global_gemm_buffer, n, device_prop));
} else {
// global q
const HipT* global_q_weight = global_weights_data;
HipT* global_q = global_gemm_buffer + 2 * batch_size * sequence_length * hidden_size;
if (disable_compact_memory) {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_q_weight, n,
input_data, k,
&zero, global_q, n, device_prop));
} else {
ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(rocblas,
rocblas_operation_none,
rocblas_operation_none,
hidden_size, // m
max_num_global, // n
hidden_size, // k
&one, // alpha
global_q_weight, // A
hidden_size, // lda
0, // strideA
input_data, // B
hidden_size, // ldb
sequence_length * hidden_size, // strideB
&zero, // beta
global_q, // C
hidden_size, // ldc
max_num_global * hidden_size, // strideC
batch_size, // batch count
device_prop));
}
// global k
const HipT* global_k_weight = global_weights_data + hidden_size * hidden_size;
HipT* global_k = global_gemm_buffer;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_k_weight, n,
input_data, k,
&zero, global_k, n, device_prop));
// global v
const HipT* global_v_weight = global_k_weight + hidden_size * hidden_size;
HipT* global_v = global_gemm_buffer + batch_size * sequence_length * hidden_size;
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
global_v_weight, n,
input_data, k,
&zero, global_v, n, device_prop));
}
}
size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size,
batch_size,
num_heads_,
head_size,
sequence_length,
max_num_global,
window_,
disable_compact_memory);
auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(
device_prop,
rocblas,
stream,
reinterpret_cast<const HipT*>(gemm_buffer.get()),
reinterpret_cast<const HipT*>(bias->Data<T>()),
reinterpret_cast<const HipT*>(attention_mask->Data<T>()),
reinterpret_cast<const HipT*>(global_gemm_buffer),
reinterpret_cast<const HipT*>(global_bias->Data<T>()),
global_attention_mask->Data<int>(),
global_index_buffer.get(),
batch_global_num_buffer.get(),
pinned_buffer.get(),
workspace_buffer.get(),
output->MutableData<T>(),
batch_size,
sequence_length,
num_heads_,
head_size,
window_,
max_num_global,
element_size,
disable_compact_memory,
use_merged_qkv_weights,
use_half4_));
  // Defer release of the pinned memory: hipStreamSynchronize is not used here and the kernel still needs access to the buffer.
this->AddDeferredReleaseCPUPtr(pinned_buffer.release());
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class LongformerAttention final : public RocmKernel, public LongformerAttentionBase {
public:
LongformerAttention(const OpKernelInfo& info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
bool use_compact_memory_;
bool use_half4_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Limitations of current Longformer Attention ROCM Kernels:
// (1) Does not support global tokens in the middle. All global tokens shall be at the beginning of the sequence.
// (2) Maximum number of global tokens <= one-sided attention window
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <limits>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_softmax.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
using namespace onnxruntime::rocm;
using namespace hipcub;
#define CHECK(expr) ROCBLAS_RETURN_IF_ERROR(expr)
#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr)
namespace onnxruntime {
namespace contrib {
namespace rocm {
// Denote: batch size (B), sequence length (S), number of heads (N), dimension per head (H), maximum global tokens (G)
//
// Workspace layout (default data type T is float or half):
// [SoftmaxSpace] [Q:BxNxSxH] [K:BxNxSxH] [V:BxNxSxH] [Global_Q:BxNxSxH] [Global_K:BxNxSxH] [Global_V:BxNxSxH]
// where Global_Q, Global_K and Global_V are optional. They are not allocated when there is no global token.
//
// SoftmaxSpace layout is the following when compact memory is enabled:
// [scratch1: (5S-3W)*W*N*B] [scratch2: size_t 15]
// Scratch1 has 5 buffers for local and global attention calculation.
// Scratch2 has 5 input/output pointers, 5 buffer sizes and 5 strides related to scratch1.
//
// SoftmaxSpace layout is the following when compact memory is disabled:
// [scratch1: BxNxSxS] [scratch2: BxNxSxS]
static size_t Align(size_t a) {
  const size_t alignment = 128;  // Align on a 128-byte boundary to avoid "misaligned address" errors.
return CeilDiv(a, alignment) * alignment;
}
size_t GetScratch1Size(size_t element_size, size_t batch_size, size_t num_heads, size_t sequence_length, size_t window) {
size_t bytes = (5 * sequence_length - 3 * window) * window * num_heads * batch_size * element_size;
return Align(bytes);
}
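// For illustration only: with element_size = 2 (half), batch_size = 1, num_heads = 12,
// sequence_length = 4096 and window = 512, scratch1 is (5 * 4096 - 3 * 512) * 512 * 12 * 1 * 2
// = 232,783,872 bytes (~222 MiB) before the 128-byte alignment rounding.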
constexpr size_t GetScratch2Size() {
return 5 * sizeof(void*) + 10 * sizeof(size_t);
}
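// On a typical 64-bit build this is 5 * 8 + 10 * 8 = 120 bytes: five device pointers, followed by
// five buffer sizes and five strides (in elements) describing the scratch1 partitions.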
size_t GetLongformerSoftmaxWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t sequence_length,
size_t window,
bool disable_compact_memory) {
if (!disable_compact_memory) {
size_t scratch1_size = GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
size_t scratch2_size = GetScratch2Size();
return Align(scratch1_size + scratch2_size);
} else {
return 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
}
}
size_t GetLongformerAttentionWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t head_size,
size_t sequence_length,
size_t max_num_global,
size_t window,
bool disable_compact_memory) {
size_t softmax_size = GetLongformerSoftmaxWorkspaceSize(element_size,
batch_size,
num_heads,
sequence_length,
window,
disable_compact_memory);
size_t qkv_size = static_cast<size_t>(3) * batch_size * sequence_length * num_heads * head_size * element_size;
size_t global_qkv_size = max_num_global > 0 ? qkv_size : 0;
return softmax_size + qkv_size + global_qkv_size;
}
// Size of buffer of pinned memory in CPU. The buffer is used to copy memory between CPU and GPU.
// The buffer includes two parts: [global_count (copy of batch_global_num): int Bx1] [copy of scratch2]
size_t GetPinnedBufferSize(size_t batch_size) {
return sizeof(int) * batch_size + GetScratch2Size();
}
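// Layout of the pinned buffer (see LaunchLongformerSoftmaxKernel below): the first batch_size ints
// receive the per-batch global token counts copied from the GPU, and the remaining GetScratch2Size()
// bytes stage the pointer/size/stride table before it is copied to scratch2 on the device.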
// Softmax kernel for compact format
template <typename T, int blockSize>
__launch_bounds__(blockSize)
__global__ void LongformerSoftmaxKernel(const int* global_attention,
const int* global_index,
const int* batch_global_num,
void* buffer_pointers,
const T* attention_mask,
float scaler,
int sequence_length,
int num_heads,
int window) {
typedef hipcub::BlockReduce<float, blockSize> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_reduce_temp;
int tid = threadIdx.x;
const int batch_index = blockIdx.x / (sequence_length * num_heads);
const int row_index = blockIdx.x % sequence_length;
const int head_index = (blockIdx.x / sequence_length) % num_heads;
// Adjust the pointers for the batch
const T* mask_block = attention_mask + sequence_length * batch_index;
const int* global_index_block = global_index + sequence_length * batch_index;
const int global_num = batch_global_num[batch_index];
size_t* p_inputs = reinterpret_cast<size_t*>(buffer_pointers);
size_t* p_outputs = reinterpret_cast<size_t*>(buffer_pointers);
size_t* input_sizes = reinterpret_cast<size_t*>(buffer_pointers) + 5;
size_t* input_strides = reinterpret_cast<size_t*>(buffer_pointers) + 10;
const T* inputs[5];
T* outputs[5];
for (int i = 0; i < 5; ++i) {
inputs[i] = reinterpret_cast<T*>(p_inputs[i]) + batch_index * num_heads * input_sizes[i];
outputs[i] = reinterpret_cast<T*>(p_outputs[i]) + batch_index * num_heads * input_sizes[i];
}
// Local attention token
int col_start = 0;
int col_end = sequence_length;
bool is_local_row = (global_attention[batch_index * sequence_length + row_index] == static_cast<int>(0));
if (is_local_row) {
col_start = row_index - window;
if (col_start < 0) {
col_start = 0;
}
col_end = row_index + window + 1;
if (col_end > sequence_length) {
col_end = sequence_length;
}
}
  // If the mask is set, zero everything out to match the Hugging Face Transformers implementation.
if ((float)mask_block[row_index] != 0.f) {
if (is_local_row) {
T* output_block = nullptr;
T* output_global = nullptr;
int local_offset = row_index % window;
int local_start = 0;
int local_end = 3 * window;
if (row_index < window) {
local_start = 0;
local_end = 2 * window;
output_block = outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
} else if (row_index < sequence_length - window) {
output_block = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
} else {
local_start = 0;
local_end = 2 * window;
output_block = outputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
}
for (int i = local_start + tid; i < local_end; i += blockSize) {
output_block[i] = 0;
}
if ((row_index - 2 * window) >= 0) {
output_global = outputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
}
if (output_global != nullptr) {
for (int i = tid; i < global_num; i += blockSize) {
output_global[i] = 0;
}
}
} else {
T* output_block = outputs[4];
for (int i = tid; i < sequence_length; i += blockSize)
output_block[i] = 0;
}
return;
}
float sum_input = 0.;
__shared__ float sum_shared;
// Calculate max input
float max_input = -std::numeric_limits<float>::infinity();
__shared__ float max_shared;
if (is_local_row) {
const T* input_block = nullptr;
T* output_block = nullptr;
T* output_global = nullptr;
int local_offset = row_index % window;
int local_start = local_offset;
int local_end = local_start + 2 * window + 1;
int zero_start = 0;
int zero_end = 3 * window;
if (row_index < window) {
local_start = 0;
local_end = local_offset + window + 1;
zero_end = 2 * window;
input_block = inputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
output_block = outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
} else if (row_index < sequence_length - window) {
input_block = inputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
output_block = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
} else {
local_start = local_offset;
local_end = 2 * window;
zero_end = 2 * window;
input_block = inputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
output_block = outputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
}
const T* input_global = nullptr;
int local_global = row_index - window;
if (local_global > global_num) {
local_global = global_num;
}
if (local_global > 0) {
input_global = inputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
}
if (row_index < window) {
output_global = (T*)outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
} else if (row_index < 2 * window) {
output_global = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
} else {
output_global = outputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
}
for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
float x = input_block[i];
x = x * scaler + (float)mask_block[j];
if (max_input < x)
max_input = x;
}
if (input_global != nullptr) {
for (int i = tid; i < local_global; i += blockSize) {
float x = input_global[global_index_block[i]];
x = x * scaler + (float)mask_block[global_index_block[i]];
if (max_input < x)
max_input = x;
}
}
float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
if (tid == 0) {
max_shared = max_block;
}
__syncthreads();
for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[j] - max_shared);
sum_input += x;
}
if (input_global != nullptr) {
      // Index the mask via global_index_block so the denominator matches the exponent terms written
      // in the output loop below (the other two global loops index the mask the same way).
      for (int i = tid; i < local_global; i += blockSize) {
        float x = input_global[global_index_block[i]];
        x = expf((x)*scaler + (float)mask_block[global_index_block[i]] - max_shared);
        sum_input += x;
}
}
float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
if (tid == 0) {
sum_shared = sum_block;
}
__syncthreads();
float recip_sum = 1.f / sum_shared;
for (int i = tid + zero_start; i < local_start; i += blockSize) {
output_block[i] = (T)(0.);
}
for (int i = tid + local_end; i < zero_end; i += blockSize) {
output_block[i] = (T)(0.);
}
__syncthreads();
for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[j] - max_shared);
output_block[i] = (T)(recip_sum * x);
}
if (input_global != nullptr) {
for (int i = tid; i < local_global; i += blockSize) {
float x = input_global[global_index_block[i]];
x = expf((x)*scaler + (float)mask_block[global_index_block[i]] - max_shared);
output_global[i] = (T)(recip_sum * x);
}
}
} else {
// Global tokens
const T* input_block = inputs[4] + row_index * input_strides[4] + head_index * input_sizes[4];
T* output_block = outputs[4] + row_index * input_strides[4] + head_index * input_sizes[4];
for (int i = tid; i < sequence_length; i += blockSize) {
float x = input_block[i];
x = x * scaler + (float)mask_block[i];
if (max_input < x)
max_input = x;
}
float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
if (tid == 0) {
max_shared = max_block;
}
__syncthreads();
for (int i = tid; i < sequence_length; i += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
sum_input += x;
}
float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
if (tid == 0) {
sum_shared = sum_block;
}
__syncthreads();
float recip_sum = 1.f / sum_shared;
for (int i = tid; i < sequence_length; i += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
output_block[i] = (T)(recip_sum * x);
}
}
}
Status LaunchLongformerSoftmaxKernel(
hipStream_t stream,
rocblas_handle rocblas,
void* workspace,
const void* q, // transposed Q with shape (B, N, S, H)
const void* k, // transposed K with shape (B, N, S, H)
const void* v, // transposed V with shape (B, N, S, H)
const void* attention_mask, // attention mask with shape (B, S), with value 0 not masked and -10000 masked.
int max_num_global, // maximum number of global tokens (G)
const bool compact_global_q, // whether global_q has shape (B, N, G, H) instead of (B, N, S, H)
const void* global_q, // Q for global tokens with shape (B, N, S, H).
const void* global_k, // K for global tokens with shape (B, N, S, H)
const void* global_v, // V for global tokens with shape (B, N, S, H)
const int* global_attention, // global attention flags with shape (B, S), with value 0 for local and 1 for global.
const int* global_index, // Global index with shape (B, S)
const int* batch_global_num, // Number of global tokens per batch with shape (B, 1)
void* pinned_buffer, // Pinned memory in CPU with 2 parts: global tokens per batch, and data for scratch2
void* output, // output with shape (B, N, S, H)
float scaler, // scalar
int batch_size, // batch size
int sequence_length, // sequence length
int num_heads, // number of heads
int head_size, // hidden size per head
int window, // one sided window size
size_t element_size) { // size of element: 2 for half, and 4 for float
const int* global_count = reinterpret_cast<const int*>(pinned_buffer);
bool is_fp16 = (element_size == 2);
char* scratch1 = reinterpret_cast<char*>(workspace);
char* scratch2 = scratch1 + GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);
// Setup shared parameters for two strided batched matrix multiplies
rocblas_datatype Atype;
rocblas_datatype Btype;
rocblas_datatype Ctype;
rocblas_datatype resultType;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
__half one_fp16, zero_fp16;
float one_fp32, zero_fp32;
void *alpha, *beta_0, *beta_1;
if (is_fp16) {
one_fp16 = __float2half(1.f);
zero_fp16 = __float2half(0.f);
alpha = static_cast<void*>(&one_fp16);
beta_0 = static_cast<void*>(&zero_fp16);
beta_1 = static_cast<void*>(&one_fp16);
Atype = rocblas_datatype_f16_r;
Btype = rocblas_datatype_f16_r;
Ctype = rocblas_datatype_f16_r;
resultType = rocblas_datatype_f16_r;
algo = rocblas_gemm_algo_standard;
} else {
one_fp32 = 1.f;
zero_fp32 = 0.f;
alpha = static_cast<void*>(&one_fp32);
beta_0 = static_cast<void*>(&zero_fp32);
beta_1 = static_cast<void*>(&one_fp32);
Atype = rocblas_datatype_f32_r;
Btype = rocblas_datatype_f32_r;
Ctype = rocblas_datatype_f32_r;
resultType = rocblas_datatype_f32_r;
}
// Strided batch matrix multiply
// qk = q * k^T
// Shapes: q and k = B x N x S x H, qk = B x N x S x S
// Convert col-major to row-major by swapping q and k in Gemm
size_t elements_per_batch = num_heads * sequence_length * head_size;
int stride_per_head = sequence_length * head_size; // stride for Q, K, V and output
// Local attention part
  // S x S is calculated using sliding WxW blocks (W is the one-sided window size) like the following:
// [W][W]
// [W][W][W]
// [W][W][W]
// [W][W]
  // The first and last rows have 2 blocks per row, and the remaining rows have 3 blocks per row.
  // The calculation is split into 3 parts: the first row, the middle rows, and finally the last row.
// To save space, we do not store the whole matrix. Instead, we only allocate space for these blocks.
//
// For global attention part, we have two assumptions:
  // (1) Global tokens are at the beginning of the sequence
// (2) Number of global tokens <= attention window
//
// The results are stored in scratch1 buffer:
// Number of elements for local attention are (3*S/W-2)*W*W*N*B, or (3S-2W)*W*N*B
// Number of elements for local attends to global are (S-W)*W*N*B
// Number of elements for global attends to everything are S*W*N*B
// Total elements (FP16 or FP32) are (5S-3W)*W*N*B
const int w = window;
const int middle_count = (sequence_length - 2 * w) / w;
int last_block = (sequence_length / w) - 1;
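  // Note: middle_count and last_block use integer division over whole WxW tiles, which appears to
  // assume that sequence_length is a multiple of the one-sided window size; presumably the caller
  // pads the sequence accordingly.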
// Determine the non-zero block dimensions and pointers
// Buffer size per head for a single batch
size_t buffer_sizes[5] = {
static_cast<size_t>(w * w * 2), // first row of blocks has 2 WxW blocks
static_cast<size_t>(w * w * middle_count * 3), // middle rows of blocks have 3 WxW blocks per row
static_cast<size_t>(w * w * 2), // last row of blocks has 2 WxW blocks
static_cast<size_t>(w * (sequence_length - w)), // local attends to global: global tokens <= window size
static_cast<size_t>(w * sequence_length)}; // global attends to everything.
size_t buffer_strides[5] = {
static_cast<size_t>(w * 2),
static_cast<size_t>(w * 3),
static_cast<size_t>(w * 2),
static_cast<size_t>(w), // number of global tokens <= window size
static_cast<size_t>(sequence_length)};
void* buffer_pointers[5];
char* current_pointer = scratch1;
for (int i = 0; i < 5; ++i) {
buffer_pointers[i] = reinterpret_cast<void*>(current_pointer);
current_pointer += buffer_sizes[i] * num_heads * batch_size * element_size;
}
  // Copy to a contiguous buffer first so that we only need to call hipMemcpyAsync once.
char* temp_buffer = reinterpret_cast<char*>(pinned_buffer) + sizeof(int) * batch_size;
memcpy(temp_buffer, &buffer_pointers[0], 5 * sizeof(void*));
memcpy(temp_buffer + 5 * sizeof(void*), &buffer_sizes[0], 5 * sizeof(size_t));
memcpy(temp_buffer + 5 * sizeof(void*) + 5 * sizeof(size_t), &buffer_strides[0], 5 * sizeof(size_t));
CHECK_ROCM(hipMemcpyAsync(scratch2, temp_buffer, GetScratch2Size(), hipMemcpyHostToDevice, stream));
// Local attention part
{
// local attention per head - head
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
2 * w, // m
w, // n
head_size, // k
alpha, // alpha
k, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
q, // B
Btype, // B type
head_size, // ldb
stride_per_head, // strideB
beta_0, // beta
buffer_pointers[0], // C
Ctype, // C type
2 * w, // ldc
buffer_sizes[0], // strideC
batch_size * num_heads, // batch count
resultType,
algo));
// local attention per head - middle
if (middle_count > 0) {
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_heads; ++j) {
const void* q_head = reinterpret_cast<const char*>(q) +
(i * elements_per_batch + (j * sequence_length + w) * head_size) * element_size;
const void* k_head = reinterpret_cast<const char*>(k) +
(i * elements_per_batch + j * sequence_length * head_size) * element_size;
void* qk_head = reinterpret_cast<char*>(buffer_pointers[1]) +
static_cast<size_t>(i * num_heads + j) * buffer_sizes[1] * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
3 * w, // m
w, // n
head_size, // k
alpha, // alpha
k_head, // A
Atype, // A type
head_size, // lda
w * head_size, // strideA
q_head, // B
Btype, // B type
head_size, // ldb
w * head_size, // strideB
beta_0, // beta
qk_head, // C
Ctype, // C type
3 * w, // ldc
3 * w * w, // strideC
middle_count, // batch count
resultType,
algo));
}
}
}
// local attention per head - tail
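// Reading aid: scores the last W query rows [S-W, S) against the last 2W key rows [S-2W, S),
// writing a 2W x W block into buffer 2.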
const void* q_head = reinterpret_cast<const char*>(q) + (last_block * w * head_size) * element_size;
const void* k_head = reinterpret_cast<const char*>(k) + ((last_block - 1) * w * head_size) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
2 * w, // m
w, // n
head_size, // k
alpha, // alpha
k_head, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
q_head, // B
Btype, // B type
head_size, // ldb
stride_per_head, // strideB
beta_0, // beta
buffer_pointers[2], // C
Ctype, // C type
2 * w, // ldc
buffer_sizes[2], // strideC
batch_size * num_heads, // batch count
resultType,
algo));
}
// Global attention part
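// Reading aid for the two GEMMs below (per batch with G = global_count[i] > 0):
// (1) scores of query rows [W, S) against the G global key rows (the first G rows of K),
//     written to buffer 3 with leading dimension W (valid because G <= W);
// (2) scores of the G global query rows against all S key rows of global_k,
//     written to buffer 4 with leading dimension S.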
for (int i = 0; i < batch_size; ++i) {
if (global_count[i] > 0) {
const void* q_batch = reinterpret_cast<const char*>(q) + (i * elements_per_batch + w * head_size) * element_size;
const void* k_batch = reinterpret_cast<const char*>(k) + (i * elements_per_batch) * element_size;
void* qk_batch = reinterpret_cast<char*>(buffer_pointers[3]) + (i * buffer_sizes[3]) * num_heads * element_size;
// Local tokens attending global tokens
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
global_count[i], // m
sequence_length - w, // n
head_size, // k
alpha, // alpha
k_batch, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
q_batch, // B
Btype, // B type
head_size, // ldb
stride_per_head, // strideB
beta_0, // beta
qk_batch, // C
Ctype, // C type
w, // ldc
buffer_sizes[3], // strideC
num_heads, // batch count
resultType,
algo));
const size_t global_q_per_batch = compact_global_q ? num_heads * max_num_global * head_size : elements_per_batch;
const int global_q_stride = (compact_global_q ? max_num_global * head_size : stride_per_head);
const void* global_q_batch = reinterpret_cast<const char*>(global_q) + (i * global_q_per_batch) * element_size;
const void* global_k_batch = reinterpret_cast<const char*>(global_k) + (i * elements_per_batch) * element_size;
qk_batch = reinterpret_cast<char*>(buffer_pointers[4]) + (i * buffer_sizes[4] * num_heads) * element_size;
// Global tokens attending everything
// This GEMM needs to be last to make sure all global token entries are rewritten.
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
sequence_length, // m
global_count[i], // n
head_size, // k
alpha, // alpha
global_k_batch, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
global_q_batch, // B
Btype, // B type
head_size, // ldb
global_q_stride, // strideB.
beta_0, // beta
qk_batch, // C
Ctype, // C type
sequence_length, // ldc
buffer_sizes[4], // strideC
num_heads, // batch count
resultType,
algo));
}
}
const int blockSize = 64;
const int gridSize = batch_size * num_heads * sequence_length;
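// The launch below uses one 64-thread block per (batch, head, sequence position), i.e.
// batch_size * num_heads * sequence_length blocks in total; the kernel presumably locates each
// row's score blocks through the pointer/size/stride metadata copied into scratch2 above.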
if (is_fp16) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxKernel<__half, blockSize>), gridSize, blockSize, 0, stream,
global_attention,
global_index,
batch_global_num,
scratch2,
static_cast<const __half*>(attention_mask),
scaler, sequence_length, num_heads, window);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxKernel<float, blockSize>), gridSize, blockSize, 0, stream,
global_attention,
global_index,
batch_global_num,
scratch2,
static_cast<const float*>(attention_mask),
scaler, sequence_length, num_heads, window);
}
// Multiply the local attention probabilities with the values (output = prob * v).
{
// local attention per head - head
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size, // m
w, // n
2 * w, // k
alpha, // alpha
v, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
buffer_pointers[0], // B
Btype, // B type
static_cast<int>(buffer_strides[0]), // ldb
buffer_sizes[0], // strideB
beta_0, // beta
output, // C
Ctype, // C type
head_size, // ldc
stride_per_head, // strideC
batch_size * num_heads, // batch count
resultType,
algo));
// local attention per head - middle
if (middle_count > 0) {
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_heads; ++j) {
const void* v_head = reinterpret_cast<const char*>(v) +
(i * elements_per_batch + j * head_size * sequence_length) * element_size;
const void* prob_head = reinterpret_cast<const char*>(buffer_pointers[1]) +
(i * num_heads + j) * buffer_sizes[1] * element_size;
void* out_head = reinterpret_cast<char*>(output) +
(i * elements_per_batch + j * head_size * sequence_length + w * head_size) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size, // m
w, // n
3 * w, // k
alpha, // alpha
v_head, // A
Atype, // A type
head_size, // lda
w * head_size, // strideA
prob_head, // B
Btype, // B type
static_cast<int>(buffer_strides[1]), // ldb
3 * w * w, // strideB
beta_0, // beta
out_head, // C
Ctype, // C type
head_size, // ldc
w * head_size, // strideC
middle_count, // batch count
resultType,
algo));
}
}
}
// local attention per head - tail
const void* v_head = reinterpret_cast<const char*>(v) + (last_block - 1) * w * head_size * element_size;
void* out_head = reinterpret_cast<char*>(output) + last_block * w * head_size * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size, // m
w, // n
2 * w, // k
alpha, // alpha
v_head, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
buffer_pointers[2], // B
Btype, // B type
static_cast<int>(buffer_strides[2]), // ldb
buffer_sizes[2], // strideB
beta_0, // beta
out_head, // C
Ctype, // C type
head_size, // ldc
stride_per_head, // strideC
batch_size * num_heads, // batch count
resultType,
algo));
}
// global attention part
for (int i = 0; i < batch_size; ++i) {
if (global_count[i] > 0) {
// Local tokens attending global tokens
const void* v_head = reinterpret_cast<const char*>(v) + (i * elements_per_batch) * element_size;
const void* prob_head = reinterpret_cast<const char*>(buffer_pointers[3]) +
(i * buffer_sizes[3] * num_heads + w * buffer_strides[3]) * element_size;
void* out_head = reinterpret_cast<char*>(output) + (i * elements_per_batch + 2 * w * head_size) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size, // m
sequence_length - 2 * w, // n
global_count[i], // k
alpha, // alpha
v_head, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
prob_head, // B
Btype, // B type
static_cast<int>(buffer_strides[3]), // ldb
buffer_sizes[3], // strideB
beta_1, // beta
out_head, // C
Ctype, // C type
head_size, // ldc
stride_per_head, // strideC
num_heads, // batch count
resultType,
algo));
// Global tokens attending everything
v_head = reinterpret_cast<const char*>(global_v) + (i * elements_per_batch) * element_size;
prob_head = reinterpret_cast<const char*>(buffer_pointers[4]) + (i * buffer_sizes[4] * num_heads) * element_size;
out_head = reinterpret_cast<char*>(output) + (i * elements_per_batch) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size, // m
global_count[i], // n
sequence_length, // k: re-write entries completely
alpha, // alpha
v_head, // A
Atype, // A type
head_size, // lda
stride_per_head, // strideA
prob_head, // B
Btype, // B type
static_cast<int>(buffer_strides[4]), // ldb
buffer_sizes[4], // strideB
beta_0, // beta: overwrite
out_head, // C: assumes global tokens at the beginning of sequence
Ctype, // C type
head_size, // ldc
stride_per_head, // strideC
num_heads, // batch count
resultType,
algo));
}
}
return Status::OK();
}
template <typename T>
Status LongformerQkvToContext(
const hipDeviceProp_t& device_prop,
rocblas_handle rocblas,
hipStream_t stream,
const int batch_size, // batch size
const int sequence_length, // sequence length
const int num_heads, // number of attention heads
const int head_size, // hidden size per head
const int window, // Half (one-sided) window size
const size_t element_size,
const T* input, // input for transpose
const T* bias, // bias to add to transposed input
const T* attention_mask, // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
const T* global_input, // global input for transpose
const T* global_bias, // bias to add to transposed global input
const int* global_attention, // global attention flags with shape (B, S), with value 0 for local and 1 for global.
const int* global_index, // Global index with shape (B, S)
const int* batch_global_num, // Number of global tokens per batch with shape (B, 1)
const int max_num_global, // Maximum number of global tokens (G)
void* pinned_buffer, // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
T* workspace, // Softmax space
T* output, // output
size_t softmax_workspace_size,
bool disable_compact_memory,
bool use_merged_qkv_weights,
bool use_half4) {
T* qkv = reinterpret_cast<T*>(reinterpret_cast<char*>(workspace) + softmax_workspace_size);
// The numbers of elements in Q, K, V, Global_Q, Global_K and Global_V are the same: BxNxSxH
const int elements = batch_size * num_heads * sequence_length * head_size;
const int max_threads_per_block(device_prop.maxThreadsPerBlock);
const int format = static_cast<int>(use_merged_qkv_weights);
bool compact_global_q = false;
// The order of qkv space:
// Q, K, V, Global_K, Global_V, Global_Q (format 0)
// Q, K, V, Global_Q, Global_K, Global_V (format 1)
// Assume Q, K and V have the same hidden size
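// Reading aid, derived from the pointer setup further below: with E = elements = B*N*S*H,
// format 0 places Q, K, V, Global_K, Global_V, Global_Q at offsets 0, E, 2E, 3E, 4E, 5E,
// while format 1 places Q, K, V, Global_Q, Global_K, Global_V at those offsets.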
if (format == 1 || max_num_global == 0 || nullptr == global_input) {
if (bias == nullptr) {
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, input, qkv));
} else {
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
input, bias, qkv,
use_half4, head_size);
}
if (max_num_global > 0 && nullptr != global_input) {
if (global_bias == nullptr) {
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, global_input, qkv + 3 * elements));
} else {
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
global_input, global_bias, qkv + 3 * elements,
use_half4, head_size);
}
}
} else {
LaunchAddBiasTranspose(stream, 5, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
input, bias, qkv,
use_half4, head_size);
compact_global_q = (disable_compact_memory == false);
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size,
compact_global_q ? max_num_global : sequence_length, num_heads, head_size,
global_input + 2 * elements, global_bias, qkv + 5 * elements,
use_half4, head_size);
}
HIP_RETURN_IF_ERROR(hipGetLastError());
// Transposed Q, K, V with shape (B, N, S, H)
const T* q = qkv;
const T* k = q + elements;
const T* v = k + elements;
// Transposed global Q, K, V with shape (B, N, S, H).
// When compact_global_q is true, Global Q has actual shape (B, N, G, H) although we allocated space of (B, N, S, H)
// When max_num_global == 0, these pointers are not used in GEMM so the value does not matter.
const T* global_q = (format == 1 ? v + elements : qkv + 5 * elements);
const T* global_k = (format == 1 ? global_q + elements : qkv + 3 * elements);
const T* global_v = (format == 1 ? global_k + elements : qkv + 4 * elements);
// Q*K' is scaled by 1/sqrt(H)
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
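// For example, head_size = 64 gives a scale of 1/8 = 0.125; it is applied inside the softmax
// kernels via the `scaler` argument rather than being folded into the GEMM alpha.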
T* temp_output = qkv; // Q will be overwritten
if (disable_compact_memory) {
ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxSimpleKernel(
stream,
rocblas,
workspace,
q,
k,
v,
attention_mask,
global_q,
global_k,
global_v,
global_attention,
global_index,
batch_global_num,
pinned_buffer,
temp_output,
rsqrt_head_size,
batch_size,
sequence_length,
num_heads,
head_size,
window,
element_size));
} else {
ORT_ENFORCE(max_num_global <= window);
ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxKernel(
stream,
rocblas,
workspace,
q,
k,
v,
attention_mask,
max_num_global,
compact_global_q,
global_q,
global_k,
global_v,
global_attention,
global_index,
batch_global_num,
pinned_buffer,
temp_output,
rsqrt_head_size,
batch_size,
sequence_length,
num_heads,
head_size,
window,
element_size));
}
// temp_output is BxNxSxH; transpose it to the final output BxSxNxH
return LaunchTransCtx(stream, sequence_length, batch_size, head_size,
num_heads, max_threads_per_block, false, temp_output, output);
}
Status LaunchLongformerAttentionKernel(
const hipDeviceProp_t& device_prop,
rocblas_handle rocblas,
hipStream_t stream,
const void* input,
const void* bias,
const void* attention_mask,
const void* global_input,
const void* global_bias,
const int* global_attention,
const int* global_index,
const int* batch_global_num,
void* pinned_buffer,
void* workspace,
void* output,
int batch_size,
int sequence_length,
int num_heads,
int head_size,
int window,
int max_num_global,
const size_t element_size,
bool disable_compact_memory,
bool use_merged_qkv_weights,
bool use_half4) {
CompatRocblasMathModeSetter helper(device_prop, rocblas, 0 /* CUBLAS_TENSOR_OP_MATH is deprecated */);
size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size,
batch_size,
num_heads,
sequence_length,
window,
disable_compact_memory);
if (element_size == 2) {
return LongformerQkvToContext(device_prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, window, element_size,
reinterpret_cast<const half*>(input),
reinterpret_cast<const half*>(bias),
reinterpret_cast<const half*>(attention_mask),
reinterpret_cast<const half*>(global_input),
reinterpret_cast<const half*>(global_bias),
global_attention,
global_index,
batch_global_num,
max_num_global,
pinned_buffer,
reinterpret_cast<half*>(workspace),
reinterpret_cast<half*>(output),
softmax_workspace_size,
disable_compact_memory,
use_merged_qkv_weights,
use_half4);
} else {
return LongformerQkvToContext(device_prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, window, element_size,
reinterpret_cast<const float*>(input),
reinterpret_cast<const float*>(bias),
reinterpret_cast<const float*>(attention_mask),
reinterpret_cast<const float*>(global_input),
reinterpret_cast<const float*>(global_bias),
global_attention,
global_index,
batch_global_num,
max_num_global,
pinned_buffer,
reinterpret_cast<float*>(workspace),
reinterpret_cast<float*>(output),
softmax_workspace_size,
disable_compact_memory,
use_merged_qkv_weights,
false);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
size_t GetPinnedBufferSize(
size_t batch_size);
size_t GetLongformerAttentionWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t head_size,
size_t sequence_length,
size_t max_num_global,
size_t window,
bool disable_compact_memory);
Status LaunchLongformerAttentionKernel(
const hipDeviceProp_t& device_prop, // Device Properties
rocblas_handle rocblas, // Rocblas handle
hipStream_t stream, // ROCM stream
const void* input, // Input tensor
const void* bias, // Bias tensor
const void* attention_mask, // Attention mask with shape (B, S)
const void* global_input, // Global attention input, or nullptr when max_num_global == 0.
const void* global_bias, // Global bias tensor
const int* global_attention, // Global attention flags with shape (B, S)
const int* global_index, // Global index
const int* batch_global_num, // Number of global tokens per batch. It is in device memory.
void* pinned_buffer, // Pinned memory: copy of batch_global_num, and a buffer to copy to scratch2.
void* workspace, // Temporary buffer
void* output, // Output tensor
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
int num_heads, // Number of attention heads (N)
int head_size, // Hidden layer size per head (H)
int window, // One sided attention window (W)
int max_num_global, // Maximum number of global tokens (G)
const size_t element_size, // Element size of input tensor,
bool disable_compact_memory, // Disable compact memory kernel
bool use_merged_qkv_weights,
bool use_half4);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// These are ROCm kernels for Longformer attention softmax that do not use compact memory.
// They use two temporary matrices of size BxNxSxS, so they consume more memory when the sequence length is large.
// The logic is simpler and has fewer constraints (for example, the number of global tokens can exceed the attention window).
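// For illustration only (hypothetical sizes): with B=1, N=12, S=4096 in FP16, each BxNxSxS
// buffer is 12 * 4096 * 4096 * 2 bytes ~= 384 MiB, so the two temporary buffers together need
// roughly 768 MiB, whereas the compact-memory path above scales with S*W instead of S*S.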
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <limits>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention_softmax.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
using namespace hipcub;
#define CHECK(expr) ROCBLAS_RETURN_IF_ERROR(expr)
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T, int blockSize>
__launch_bounds__(blockSize)
__global__ void LongformerSoftmaxSimpleKernel(const int* global_attention,
const int* global_index,
const int* batch_global_num,
const T* input,
const T* attention_mask,
T* output,
float scaler,
int dim0,
int sequence_length,
int attention_window) {
typedef hipcub::BlockReduce<float, blockSize> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_reduce_temp;
__shared__ float max_shared;
__shared__ float sum_shared;
const T* input_block = input + sequence_length * blockIdx.x;
T* output_block = output + sequence_length * blockIdx.x;
const int batch_index = blockIdx.x / dim0;
const int row_index = blockIdx.x % sequence_length;
const int global_num = batch_global_num[batch_index];
// To be consistent with HuggingFace Longformer, the rows of masked words are set to zero.
if ((float)attention_mask[batch_index * sequence_length + row_index] < 0.0f) {
for (int i = threadIdx.x; i < sequence_length; i += blockSize) {
output_block[i] = (T)(0);
}
return;
}
// Determine the attended column range: full range by default (global tokens), window for local tokens.
int col_start = 0;
int col_end = sequence_length;
bool is_local_row = (global_attention[batch_index * sequence_length + row_index] == (int)0);
if (is_local_row) {
col_start = row_index - attention_window;
if (col_start < 0) {
col_start = 0;
}
col_end = row_index + attention_window + 1;
if (col_end > sequence_length) {
col_end = sequence_length;
}
}
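// For example (hypothetical values), with attention_window = 256 and sequence_length = 4096:
// row_index = 1000 gives columns [744, 1257), while row_index = 10 is clipped to [0, 267).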
const T* mask_block = attention_mask + sequence_length * batch_index;
int tid = threadIdx.x;
// calculate max input
float max_input = -std::numeric_limits<float>::infinity();
// #pragma unroll 16
for (int i = tid + col_start; i < col_end; i += blockSize) {
float x = input_block[i];
x = x * scaler + (float)mask_block[i];
if (max_input < x) {
max_input = x;
}
}
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = x * scaler + (float)mask_block[i];
if (max_input < x) {
max_input = x;
}
}
}
}
float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
if (tid == 0) {
max_shared = max_block;
}
__syncthreads();
float sum_input = 0.f;
// #pragma unroll 16
for (int i = tid + col_start; i < col_end; i += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
sum_input += x;
}
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
sum_input += x;
}
}
}
float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
if (tid == 0) {
sum_shared = sum_block;
}
__syncthreads();
float recip_sum = 1.f / sum_shared;
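// The remaining loops write out the numerically stable softmax:
//   output[i] = exp(input[i] * scaler + mask[i] - max_shared) * recip_sum
// over [col_start, col_end), plus the global columns for local rows. For local rows, nearby
// out-of-window entries are zeroed first so the block GEMMs after the softmax read valid data.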
if (is_local_row) {
// We only need to fill in zeros for blocks that will be used in the matrix multiplication
// following the Softmax.
//
// For now we zero out only [row_index - 2*attention_window, row_index + 2*attention_window];
// we could be even more aggressive and shrink the zeroed-out window, since
// each row has entries in 3 blocks (3*attention_window instead of 4*attention_window).
int zero_start = row_index - 2 * attention_window;
if (zero_start < 0) {
zero_start = 0;
}
int zero_end = row_index + 2 * attention_window;
if (zero_end > sequence_length) {
zero_end = sequence_length;
}
for (int i = tid + zero_start; i < zero_end; i += blockSize) {
if (i < col_start || i >= col_end) {
output_block[i] = (T)(0.);
}
}
}
__syncthreads();
if (is_local_row) {
for (int g = tid; g < global_num; g += blockSize) {
int i = global_index[g];
if (i < col_start || i >= col_end) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
output_block[i] = (T)(recip_sum * x);
}
}
}
// #pragma unroll 16
for (int i = tid + col_start; i < col_end; i += blockSize) {
float x = input_block[i];
x = expf((x)*scaler + (float)mask_block[i] - max_shared);
output_block[i] = (T)(recip_sum * x);
}
}
// Launch the softmax kernel for the non-compact memory path.
Status LaunchLongformerSoftmaxSimpleKernel(
hipStream_t stream,
rocblas_handle rocblas,
void* workspace, // softmax space
const void* q, // transposed Q with shape (B, N, S, H)
const void* k, // transposed K with shape (B, N, S, H)
const void* v, // transposed V with shape (B, N, S, H)
const void* attention_mask, // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
const void* global_q, // Q for global tokens with shape (B, N, S, H)
const void* global_k, // K for global tokens with shape (B, N, S, H)
const void* global_v, // V for global tokens with shape (B, N, S, H)
const int* global_attention, // global attention flags with shape (B, S), with value 0 for local and 1 for global.
const int* global_index, // Global index with shape (B, S)
const int* batch_global_num, // Number of global tokens per batch with shape (B, 1)
void* pinned_buffer, // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
void* output, // output with shape (B, N, S, H)
float scaler,                // scale factor applied to Q*K', i.e. 1/sqrt(head_size)
int batch_size, // batch size
int sequence_length, // sequence length
int num_heads, // number of heads
int head_size, // hidden size per head
int attention_window,        // one-sided window size
size_t element_size) { // size of element: 2 for half, and 4 for float
bool is_fp16 = (element_size == 2);
void* scratch1 = reinterpret_cast<char*>(workspace);
size_t scratch1_size = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
void* scratch2 = reinterpret_cast<char*>(scratch1) + scratch1_size;
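// scratch1 receives the raw QK' scores and scratch2 (softmax_out below) receives the softmax
// probabilities; given the arguments passed to GetAttentionScratchSize, each is presumably a
// full B x N x S x S buffer (possibly with alignment padding).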
// Set up shared parameters for the strided batched matrix multiplies
rocblas_datatype Atype;
rocblas_datatype Btype;
rocblas_datatype Ctype;
rocblas_datatype resultType;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
__half one_fp16, zero_fp16;
float one_fp32, zero_fp32;
void *alpha, *beta_0, *beta_1;
if (is_fp16) {
one_fp16 = __float2half(1.f);
zero_fp16 = __float2half(0.f);
alpha = static_cast<void*>(&one_fp16);
beta_0 = static_cast<void*>(&zero_fp16);
beta_1 = static_cast<void*>(&one_fp16);
Atype = rocblas_datatype_f16_r;
Btype = rocblas_datatype_f16_r;
Ctype = rocblas_datatype_f16_r;
resultType = rocblas_datatype_f16_r;
algo = rocblas_gemm_algo_standard;
} else {
one_fp32 = 1.f;
zero_fp32 = 0.f;
alpha = static_cast<void*>(&one_fp32);
beta_0 = static_cast<void*>(&zero_fp32);
beta_1 = static_cast<void*>(&one_fp32);
Atype = rocblas_datatype_f32_r;
Btype = rocblas_datatype_f32_r;
Ctype = rocblas_datatype_f32_r;
resultType = rocblas_datatype_f32_r;
}
// Strided batch matrix multiply
// qk = q * k^T
// Shapes: q and k = B x N x S x H, qk = B x N x S x S
// Convert col-major to row-major by swapping q and k in Gemm
// Local attention part
// S x S is calculated using sliding WxW blocks (W is the one-sided window size) like the following:
// [W][W]
// [W][W][W]
// [W][W][W]
// [W][W]
// The first and last rows have 2 blocks each, and the remaining rows have 3 blocks per row.
// The calculation is split into 3 parts: fill the middle rows, then the first row, and finally the last row.
// The results are stored in scratch1.
int w = attention_window;
size_t x_offset = static_cast<size_t>(num_heads) * sequence_length * head_size;
// Use size_t to avoid integer overflow since B x N x S x S is 12G for B=64, N=12, S=4096
size_t y_offset = static_cast<size_t>(num_heads) * sequence_length * sequence_length;
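// (64 * 12 * 4096 * 4096 = 12,884,901,888 elements, well beyond INT32_MAX = 2,147,483,647.)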
int last_block = (sequence_length / w) - 1;
int strideA = sequence_length * head_size;
int strideB = sequence_length * head_size;
int strideC = sequence_length * sequence_length;
// When S == 2W, there is no middle rows of blocks:
// [W][W]
// [W][W]
// We can use normal matrix multiplication in this case.
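// For example (hypothetical sizes): S = 1024 with W = 512 takes the full-GEMM branch below,
// while S = 4096 with W = 512 takes the sliding-block branch with (S - 2W) / W = 6 middle
// row-blocks per (batch, head).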
if (sequence_length == 2 * w) {
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
sequence_length,
sequence_length,
head_size,
alpha,
k,
Atype,
head_size,
sequence_length * head_size,
q,
Btype,
head_size,
sequence_length * head_size,
beta_0,
scratch1,
Ctype,
sequence_length,
sequence_length * sequence_length,
batch_size * num_heads,
resultType,
algo));
} else { // sequence_length > 2 * w
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_heads; ++j) {
const void* q_head = reinterpret_cast<const char*>(q) +
(i * x_offset + j * sequence_length * head_size + w * head_size) * element_size;
const void* k_head = reinterpret_cast<const char*>(k) +
(i * x_offset + j * sequence_length * head_size) * element_size;
void* qk_head = reinterpret_cast<char*>(scratch1) +
(i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size;
int count = (sequence_length - 2 * w) / w;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
3 * w, // m
w, // n
head_size, // k
alpha, // alpha
k_head, // A
Atype, // A type
head_size, // lda
w * head_size, // strideA
q_head, // B
Btype, // B type
head_size, // ldb
w * head_size, // strideB
beta_0, // beta
qk_head, // C
Ctype, // C type
sequence_length, // ldc
sequence_length * w + w, // strideC
count, // batch count
resultType,
algo));
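// Note on the strideC of S*W + W used above: with ldc = S (column-major), each successive
// strided batch's 3W x W block is shifted down by W key rows and right by W query columns
// inside the per-head S x S score matrix, tracing the diagonal band of blocks.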
}
}
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
2 * w, // m
w, // n
head_size, // k
alpha, // alpha
k, // A
Atype, // A type
head_size, // lda
strideA, // strideA
q, // B
Btype, // B type
head_size, // ldb
strideB, // strideB
beta_0, // beta
scratch1, // C
Ctype, // C type
sequence_length, // ldc
strideC, // strideC
batch_size * num_heads, // batch count
resultType,
algo));
const void* q_head = reinterpret_cast<const char*>(q) + (last_block * w * head_size) * element_size;
const void* k_head = reinterpret_cast<const char*>(k) + ((last_block - 1) * w * head_size) * element_size;
void* qk_head = reinterpret_cast<char*>(scratch1) +
(last_block * w * sequence_length + (last_block - 1) * w) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
2 * w,
w,
head_size,
alpha,
k_head,
Atype,
head_size,
strideA,
q_head,
Btype,
head_size,
strideB,
beta_0,
qk_head,
Ctype,
sequence_length,
strideC,
batch_size * num_heads,
resultType,
algo));
}
const int* batch_global_count = reinterpret_cast<const int*>(pinned_buffer);
// Global attention part
for (int i = 0; i < batch_size; ++i) {
if (batch_global_count[i] > 0) {
const void* q_batch = reinterpret_cast<const char*>(q) + (i * x_offset) * element_size;
const void* k_batch = reinterpret_cast<const char*>(k) + (i * x_offset) * element_size;
void* qk_batch = reinterpret_cast<char*>(scratch1) + (i * y_offset) * element_size;
// Local tokens attending global tokens
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
batch_global_count[i],
sequence_length,
head_size,
alpha,
k_batch,
Atype,
head_size,
strideA,
q_batch,
Btype,
head_size,
strideB,
beta_0,
qk_batch,
Ctype,
sequence_length,
strideC,
num_heads,
resultType,
algo));
const void* global_q_batch = reinterpret_cast<const char*>(global_q) +
(i * num_heads * sequence_length * head_size) * element_size;
const void* global_k_batch = reinterpret_cast<const char*>(global_k) + (i * x_offset) * element_size;
int strideB_global = sequence_length * head_size;
// Global tokens attending everything
// This GEMM needs to be last to make sure all global token entries are rewritten.
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_transpose,
rocblas_operation_none,
sequence_length,
batch_global_count[i],
head_size,
alpha,
global_k_batch,
Atype,
head_size,
strideA,
global_q_batch,
Btype,
head_size,
strideB_global,
beta_0,
qk_batch,
Ctype,
sequence_length,
strideC,
num_heads,
resultType,
algo));
}
}
int dim0 = sequence_length * num_heads;
int dim1 = sequence_length;
void* softmax_out = scratch2;
const int blockSize = 64;
const int gridSize = batch_size * num_heads * sequence_length;
if (is_fp16) {
hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxSimpleKernel<__half, blockSize>), gridSize, blockSize, 0, stream,
global_attention,
global_index,
batch_global_num,
static_cast<const __half*>(scratch1),
static_cast<const __half*>(attention_mask),
static_cast<__half*>(softmax_out), scaler, dim0, dim1, attention_window);
} else {
hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxSimpleKernel<float, blockSize>), gridSize, blockSize, 0, stream,
global_attention,
global_index,
batch_global_num,
static_cast<const float*>(scratch1),
static_cast<const float*>(attention_mask),
static_cast<float*>(softmax_out), scaler, dim0, dim1, attention_window);
}
// Run the matrix multiply: output = softmax_out * v
// softmax_out: B x N x S x S
// v: B x N x S x H
// attn_out: B x N x S x H
// The calculation uses a full GEMM (S == 2W) or sliding blocks (S > 2W), similar to the local attention part.
if (sequence_length == 2 * w) {
// convert col-major to row-major by swapping softmax_out and v
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
sequence_length,
sequence_length,
alpha,
v,
Atype,
head_size,
sequence_length * head_size,
softmax_out,
Btype,
sequence_length,
sequence_length * sequence_length,
beta_0,
output,
Ctype,
head_size,
sequence_length * head_size,
batch_size * num_heads,
resultType,
algo));
} else { // sequence_length > 2 * w
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_heads; ++j) {
const void* v_head = reinterpret_cast<const char*>(v) +
(i * x_offset + j * head_size * sequence_length) * element_size;
size_t offset = (i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size;
const void* prob_head = reinterpret_cast<const char*>(softmax_out) + offset;
void* out_head = reinterpret_cast<char*>(output) +
(i * x_offset + j * head_size * sequence_length + w * head_size) * element_size;
int count = (sequence_length - 2 * w) / w;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
w,
3 * w,
alpha,
v_head,
Atype,
head_size,
w * head_size,
prob_head,
Btype,
sequence_length,
sequence_length * w + w,
beta_0,
out_head,
Ctype,
head_size,
w * head_size,
count,
resultType,
algo));
}
}
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
w,
2 * w,
alpha,
v,
Atype,
head_size,
sequence_length * head_size,
softmax_out,
Btype,
sequence_length,
sequence_length * sequence_length,
beta_0,
output,
Ctype,
head_size,
sequence_length * head_size,
batch_size * num_heads,
resultType,
algo));
const void* v_head = reinterpret_cast<const char*>(v) + (last_block - 1) * w * head_size * element_size;
const void* prob_head = reinterpret_cast<const char*>(softmax_out) +
(sequence_length * last_block * w + (last_block - 1) * w) * element_size;
void* out_head = reinterpret_cast<char*>(output) + last_block * w * head_size * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
w,
2 * w,
alpha,
v_head,
Atype,
head_size,
sequence_length * head_size,
prob_head,
Btype,
sequence_length,
sequence_length * sequence_length,
beta_0,
out_head,
Ctype,
head_size,
sequence_length * head_size,
batch_size * num_heads,
resultType,
algo));
}
for (int i = 0; i < batch_size; ++i) {
if (batch_global_count[i] > 0) {
int glob_longdim_mm = (last_block - 1) * w;
const void* v_head = reinterpret_cast<const char*>(v) + (i * x_offset) * element_size;
const void* prob_head = reinterpret_cast<const char*>(softmax_out) +
(i * y_offset + 2 * w * sequence_length) * element_size;
void* out_head = reinterpret_cast<char*>(output) + (i * x_offset + 2 * w * head_size) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
glob_longdim_mm,
batch_global_count[i],
alpha,
v_head,
Atype,
head_size,
sequence_length * head_size,
prob_head,
Btype,
sequence_length,
sequence_length * sequence_length,
beta_1,
out_head,
Ctype,
head_size,
sequence_length * head_size,
num_heads,
resultType,
algo));
// Global tokens
v_head = reinterpret_cast<const char*>(global_v) + (i * x_offset) * element_size;
prob_head = reinterpret_cast<const char*>(softmax_out) + (i * y_offset) * element_size;
out_head = reinterpret_cast<char*>(output) + (i * x_offset) * element_size;
CHECK(_compat_rocblas_gemm_strided_batched_ex(rocblas,
rocblas_operation_none,
rocblas_operation_none,
head_size,
batch_global_count[i],
sequence_length, // Re-write entries completely
alpha,
v_head,
Atype,
head_size,
sequence_length * head_size,
prob_head,
Btype,
sequence_length,
sequence_length * sequence_length,
beta_0, // Use beta=0 to overwrite
out_head, // Here assumes global tokens are at the beginning of sequence.
Ctype,
head_size,
sequence_length * head_size,
num_heads,
resultType,
algo));
}
}
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime