gaoqiong / onnxruntime_v14 · Commits

Commit 1a91fcc2, authored Jul 25, 2023 by gaoqiong
add files required by DTK
Parent: a144865d
Pipeline #492 failed with stages in 0 seconds
Changes: 280 · Pipelines: 1
Showing 20 changed files with 4579 additions and 0 deletions (+4579 -0)

build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.cc            +61   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.h             +96   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.cu       +90   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.h        +30   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/aten_ops/aten_op.cc                  +18   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.cu           +459  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.h            +43   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_concat.cu             +250  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_impl.h                +150  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_transpose.cu          +304  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.cu                 +343  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.h                  +48   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.cc            +392  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.h             +26   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh                  +158  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.cc         +295  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.h          +29   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.cu    +1059 -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.h     +53   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_softmax.cu +675  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.cc  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "activations.h"
#include "core/framework/op_kernel.h"
using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
domain, \
ver, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.MayInplace(0, 0), \
x<T>);
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToHipType<T>::MappedType>( \
Stream(), \
reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \
UNARY_ACTIVATION_COMPUTE(name, T)
#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);

REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.h  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/unary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/activation/activations.h"
#include "activations_impl.h"
using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
class Affine final : public UnaryElementwise {
 public:
  Affine(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class ParametricSoftplus final : public UnaryElementwise {
 public:
  ParametricSoftplus(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class ScaledTanh final : public UnaryElementwise {
 public:
  ScaledTanh(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class Gelu final : public UnaryElementwise {
 public:
  Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_NULL()
};

template <typename T>
class QuickGelu final : public UnaryElementwise {
 public:
  QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
    alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA()
  float alpha_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.cu  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"
using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
struct OP_Affine : public CtxAffine {
  __device__ __inline__ T operator()(const T& a) const {
    return a * (T)alpha + (T)beta;
  }
};

template <typename T>
struct OP_ParametricSoftplus : public CtxParametricSoftplus {
  __device__ __inline__ T operator()(const T& a) const {
    if (a > (T)0)
      return (T)alpha * (a * (T)beta + _Log(_Exp(-a * (T)beta) + (T)1));
    else
      return (T)alpha * _Log(_Exp(a * (T)beta) + (T)1);
  }
};

template <typename T>
struct OP_ScaledTanh : public CtxScaledTanh {
  __device__ __inline__ T operator()(const T& a) const {
    return (T)alpha * _Tanh(a * (T)beta);
  }
};

template <typename T>
struct OP_Gelu : public CtxGelu {
  __device__ __inline__ T operator()(const T& a) const {
    return _Gelu(a);
  }
};

template <>
struct OP_Gelu<half> : public CtxGelu {
  __device__ __inline__ half operator()(const half& a) const {
    return static_cast<half>(_Gelu(static_cast<float>(a)));
  }
};

template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
  __device__ __inline__ T operator()(const T& a) const {
    T v = a * static_cast<T>(alpha);
    T one = static_cast<T>(1.f);
    T zero = static_cast<T>(0.f);
    T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
    return a * sigmoid;
  }
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double)
#define UNARY_ACTIVATION_OP_NAME(name) \
UNARY_ACTIVATION_IMPL(name); \
SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
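For reference, the OP_QuickGelu functor above computes x * sigmoid(alpha * x), with the sigmoid split into two branches so that _Exp never receives a large positive argument. The following host-side sketch (not part of this commit; helper name and test values are made up for illustration) mirrors that math in plain C++ and can be used to sanity-check the kernel output:

// Host-side reference for the QuickGelu math (sketch only, not in the commit).
// quick_gelu_ref(a, alpha) reproduces the device functor: a * sigmoid(alpha * a),
// with the same sign branch to keep exp() numerically safe.
#include <cmath>
#include <cstdio>

static float quick_gelu_ref(float a, float alpha) {
  float v = a * alpha;
  float sigmoid = (v >= 0.0f) ? 1.0f / (1.0f + std::exp(-v))
                              : 1.0f - 1.0f / (1.0f + std::exp(v));
  return a * sigmoid;
}

int main() {
  const float alpha = 1.702f;  // default of the QuickGelu "alpha" attribute above
  for (float x : {-3.0f, -0.5f, 0.0f, 0.5f, 3.0f}) {
    std::printf("QuickGelu(%+.2f) = %+.6f\n", x, quick_gelu_ref(x, alpha));
  }
  return 0;
}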
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.h  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/activation/activations_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {

typedef onnxruntime::rocm::CtxAlphaBeta CtxAffine;
typedef onnxruntime::rocm::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::rocm::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::rocm::CtxNull CtxGelu;
typedef onnxruntime::rocm::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME
}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/aten_ops/aten_op.cc  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "contrib_ops/cpu/aten_ops/aten_op.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {

ONNX_OPERATOR_KERNEL_EX(
    ATen,
    kPytorchAtenDomain,
    1,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()),
    onnxruntime::contrib::ATen);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.cu  0 → 100644
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
namespace onnxruntime {
namespace rocm {

struct __align__(8) Half4 {
  half2 x;
  half2 y;
};

__device__ __forceinline__ Half4 operator+(const Half4& a, const Half4& b) {
  Half4 r;
  r.x = a.x + b.x;
  r.y = a.y + b.y;
  return r;
}

__device__ __forceinline__ float2 operator+(const float2& a, const float2& b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

__device__ __forceinline__ float4 operator+(const float4& a, const float4& b) {
  return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

}  // namespace rocm
}  // namespace onnxruntime

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
__global__ void AddBiasTransposeTrt(const T* input, const T* biases, T* output) {
  // Input: BxSxMxNxH (Format 2)
  // Output: BxSxNxMxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int H = blockDim.x;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const int NH = N * H;
  const int offset = (b * S + s) * M * NH;
  const int in_offset = offset + m * NH + n * H;
  const int out_offset = offset + (n * M + m) * H;

  const int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int H = head_size;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const int NH = N * H;
  const int offset = (b * S + s) * M * NH;
  const int in_offset = offset + m * NH + n * H;
  const int out_offset = offset + (n * M + m) * H;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTransposeTrt(const T* query, const T* key, const T* value, const T* biases, T* output) {
  // Q: BxSxNxH
  // K: BxSxNxH
  // V: BxSxNxH
  // Output: BxSxNxMxH
  // B is batch_size, S is sequence_length, M is number of matrices (3), N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int H = blockDim.x;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const T* input = (m == 0 ? query : (m == 1 ? key : value));
  const int NH = N * H;
  const int in_offset = (b * S + s) * NH + n * H;
  const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;

  const int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size, const T* query, const T* key, const T* value,
                                         const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int stride = blockDim.x;
  const int H = head_size;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const T* input = (m == 0 ? query : (m == 1 ? key : value));
  const int NH = N * H;
  const int in_offset = (b * S + s) * NH + n * H;
  const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;

  int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output) {
  // Input: BxSxMxNxH (Format 1)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int M = gridDim.z;
  const int H = head_size;

  const int NH = num_heads * head_size;
  const int NHS = NH * sequence_length;

  int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
  const int out_offset = s * head_size + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  const int h = threadIdx.x;
  if (h < head_size) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
  // Input: BxSxMxNxH (Format 1)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;        // head_num_id
  int s = blockIdx.x;         // sequence_id
  int b = blockIdx.y;         // batch_id
  int m = blockIdx.z;         // matrix id (Q=0, K=1, V=2)
  const int h = threadIdx.x;  // head_element_id

  const int qk_head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  const int head_size = (m == 2 ? v_head_size : qk_head_size);

  const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);

  int in_offset;
  int out_offset;
  int bias_offset;
  in_offset = b * (total_head_size * sequence_length) +  // B
              s * (total_head_size) +                    // S
              m * (qk_head_size * num_heads) +           // M
              n * head_size +                            // N
              h;                                         // H

  out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) +  // M
               b * (num_heads * head_size * sequence_length) +                  // B
               n * (sequence_length * head_size) +                              // N
               s * (head_size) +                                                // S
               h;                                                               // H

  bias_offset = m * (num_heads * qk_head_size) +  // M
                n * (head_size) +                 // N
                h;                                // H

  if (h < head_size) {
    output[out_offset] = input[in_offset] + biases[bias_offset];
  }
}

template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int M = gridDim.z;
  const int H = head_size;

  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
  const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
  // Input: MxBxSxNxH (Format 0)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int H = head_size;
  const int NH = num_heads * head_size;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
  const int out_offset = s * H + n * sequence_length * H + (b + m * batch_size) * NHS;

  const int h = threadIdx.x;
  if (h < head_size) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int H = head_size;
  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
  const int out_offset = (s + n * sequence_length) * H + (b + m * batch_size) * NHS;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
void InvokeAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const T* input, const T* biases, T* output, const int v_head_size) {
  const dim3 grid(sequence_length, batch_size, num_matrices);
  if (qk_head_size * num_heads <= max_threads_per_block) {
    const dim3 block(qk_head_size, num_heads, 1);
    if (format == 2) {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream,
                         input, biases, output);
    } else if (format == 1) {
      if (v_head_size == -1 || qk_head_size == v_head_size) {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream,
                           input, biases, output);
      } else {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream,
                           input, biases, output, v_head_size);
      }
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTranspose<T>), grid, block, 0, stream,
                         input, biases, output);
    }
  } else {
    const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
    if (format == 2) {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream,
                         qk_head_size, input, biases, output);
    } else if (format == 1) {
      if (v_head_size == -1 || qk_head_size == v_head_size) {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKVLarge<T>), grid, block, 0, stream,
                           qk_head_size, input, biases, output);
      } else {
        ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block");
      }
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeLarge<T>), grid, block, 0, stream,
                         qk_head_size, input, biases, output);
    }
  }
}

template <>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size) {
  if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
    const int H = qk_head_size / 4;
    const int H_v = v_head_size / 4;
    const Half4* input2 = reinterpret_cast<const Half4*>(input);
    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
    Half4* output2 = reinterpret_cast<Half4*>(output);
    InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H,
                                  input2, biases2, output2, H_v);
  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
    const int H = qk_head_size / 2;
    const int H_v = v_head_size / 2;
    const half2* input2 = reinterpret_cast<const half2*>(input);
    const half2* biases2 = reinterpret_cast<const half2*>(biases);
    half2* output2 = reinterpret_cast<half2*>(output);
    InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H,
                                  input2, biases2, output2, H_v);
  } else {
    InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
                                 batch_size, sequence_length, num_heads, qk_head_size,
                                 input, biases, output, v_head_size);
  }
}

template <>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size) {
  if (0 == (qk_head_size % 4)) {
    const int H = qk_head_size / 4;
    const float4* input2 = reinterpret_cast<const float4*>(input);
    const float4* biases2 = reinterpret_cast<const float4*>(biases);
    float4* output2 = reinterpret_cast<float4*>(output);
    InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H,
                                   input2, biases2, output2, v_head_size / 4);
  } else if (0 == (qk_head_size & 1)) {
    const int H = qk_head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    const float2* biases2 = reinterpret_cast<const float2*>(biases);
    float2* output2 = reinterpret_cast<float2*>(output);
    InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H,
                                   input2, biases2, output2, v_head_size / 2);
  } else {
    InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, qk_head_size,
                                  input, biases, output, v_head_size);
  }
}

template <typename T>
void InvokeAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const T* biases, const T* query, const T* key, const T* value, T* output) {
  constexpr int num_matrices = 3;
  const dim3 grid(sequence_length, batch_size, num_matrices);
  if (head_size * num_heads <= max_threads_per_block) {
    const dim3 block(head_size, num_heads, 1);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream,
                       query, key, value, biases, output);
  } else {
    const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream,
                       head_size, query, key, value, biases, output);
  }
}

template <>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const float* biases, const float* query, const float* key, const float* value, float* output) {
  ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
}

template <>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const half* biases, const half* query, const half* key, const half* value, half* output) {
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const Half4* query2 = reinterpret_cast<const Half4*>(query);
    const Half4* key2 = reinterpret_cast<const Half4*>(key);
    const Half4* value2 = reinterpret_cast<const Half4*>(value);
    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
    Half4* output2 = reinterpret_cast<Half4*>(output);
    InvokeAddBiasTransposeTrt<Half4>(stream, max_threads_per_block,
                                     batch_size, sequence_length, num_heads, H,
                                     biases2, query2, key2, value2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const half2* query2 = reinterpret_cast<const half2*>(query);
    const half2* key2 = reinterpret_cast<const half2*>(key);
    const half2* value2 = reinterpret_cast<const half2*>(value);
    const half2* biases2 = reinterpret_cast<const half2*>(biases);
    half2* output2 = reinterpret_cast<half2*>(output);
    InvokeAddBiasTransposeTrt<half2>(stream, max_threads_per_block,
                                     batch_size, sequence_length, num_heads, H,
                                     biases2, query2, key2, value2, output2);
  } else {
    InvokeAddBiasTransposeTrt<half>(stream, max_threads_per_block,
                                    batch_size, sequence_length, num_heads, head_size,
                                    biases, query, key, value, output);
  }
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.h  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
// Fused kernel of Add (bias) and Transpose.
// Shape of inputs and outputs:
// biases: (num_matrices, num_heads * head_size)
// format 0:
// input: (num_matrices, batch_size, sequence_length, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 1:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 2:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
template <typename T>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);

// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// It assumes sequence_length == kv_sequence_length and head_size == v_head_size.
template <typename T>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const T* biases, const T* query, const T* key, const T* value, T* output);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
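As a rough illustration of how the declaration above would be driven for format 1 (the input is a fused QKV GEMM output of shape batch_size x sequence_length x 3 x num_heads x head_size), a hypothetical host-side driver could look like the sketch below. All sizes, the stream handling, and the minimal buffer setup are assumptions for illustration only; this is not code from the commit.

// Hypothetical driver for LaunchAddBiasTranspose (format 1, float path) — sketch only.
#include <hip/hip_runtime.h>
#include "contrib_ops/rocm/bert/add_bias_transpose.h"

void RunAddBiasTransposeExample() {
  // Made-up problem sizes for illustration.
  const int batch_size = 2, sequence_length = 128, num_heads = 12, head_size = 64;
  const int num_matrices = 3;  // Q, K and V packed together (format 1)
  const size_t qkv_elems = static_cast<size_t>(batch_size) * sequence_length *
                           num_matrices * num_heads * head_size;
  const size_t bias_elems = static_cast<size_t>(num_matrices) * num_heads * head_size;

  float *input = nullptr, *biases = nullptr, *output = nullptr;
  hipMalloc(&input, qkv_elems * sizeof(float));    // (B, S, M, N, H) from the QKV GEMM
  hipMalloc(&biases, bias_elems * sizeof(float));  // (M, N * H)
  hipMalloc(&output, qkv_elems * sizeof(float));   // (M, B, N, S, H)

  hipStream_t stream = nullptr;
  hipStreamCreate(&stream);

  // format 1: (B, S, M, N, H) -> (M, B, N, S, H); v_head_size == qk_head_size here,
  // and enable_half4 is ignored by the float specialization.
  onnxruntime::contrib::rocm::LaunchAddBiasTranspose(
      stream, num_matrices, /*format=*/1, /*max_threads_per_block=*/1024,
      batch_size, sequence_length, num_heads, head_size,
      input, biases, output, /*enable_half4=*/false, /*v_head_size=*/head_size);

  hipStreamSynchronize(stream);
  hipStreamDestroy(stream);
  hipFree(input);
  hipFree(biases);
  hipFree(output);
}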
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_concat.cu  0 → 100644
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
                                     const T* tensor_in, const T* tensor_add, T* tensor_out) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int chunk_id = blockIdx.z;

  const int all_sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int num_heads = blockDim.y;
  const int H = blockDim.x;

  // K: number of identical tensors
  // tensor_in:  K x BxNxPxH
  // tensor_add: K x BxNxLxH
  // tensor_out: K x BxNxTxH, where T = P + L
  const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;

  const int present_SH = all_sequence_length * H;
  const int present_NSH = num_heads * present_SH;
  int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
  if (s < tensor_in_sequence_length) {
    const int past_SH = tensor_in_sequence_length * H;
    const int past_NSH = num_heads * past_SH;
    const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
    tensor_out[out_offset] = tensor_in[in_offset];
  } else if (s < all_sequence_length) {
    const int SH = tensor_add_sequence_length * H;
    const int NSH = num_heads * SH;
    const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
    tensor_out[out_offset] = tensor_add[in_offset];
  }
}

template <typename T>
__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, const int H,
                                          const T* tensor_in, const T* tensor_add, T* tensor_out) {
  // Use when (H*)*num_heads > 1024
  int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int chunk_id = blockIdx.z;

  const int all_sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int num_heads = blockDim.y;
  const int stride = blockDim.x;

  // K: number of identical tensor
  // tensor_in:  K x BxNxPxH
  // tensor_add: K x BxNxLxH
  // tensor_out: K x BxNxTxH
  const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;

  const int present_SH = all_sequence_length * H;
  const int present_NSH = num_heads * present_SH;
  while (h < H) {
    int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
    if (s < tensor_in_sequence_length) {
      const int past_SH = tensor_in_sequence_length * H;
      const int past_NSH = num_heads * past_SH;
      const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
      tensor_out[out_offset] = tensor_in[in_offset];
    } else if (s < all_sequence_length) {
      const int SH = tensor_add_sequence_length * H;
      const int NSH = num_heads * SH;
      const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
      tensor_out[out_offset] = tensor_add[in_offset];
    }
    h += stride;
  }
}

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const float* tensor_in, const float* tensor_add, float* tensor_out) {
  const dim3 grid(all_sequence_length, batch_size, matrix_num);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream,
                         sequence_length,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream,
                         sequence_length, H,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    }
  } else {
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float>), grid, block, 0, stream,
                         sequence_length, tensor_in, tensor_add, tensor_out);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float>), grid, block, 0, stream,
                         sequence_length, head_size, tensor_in, tensor_add, tensor_out);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const half* tensor_in, const half* tensor_add, half* tensor_out) {
  const dim3 grid(all_sequence_length, batch_size, matrix_num);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream,
                         sequence_length,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream,
                         sequence_length, H,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    }
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half2>), grid, block, 0, stream,
                         sequence_length,
                         reinterpret_cast<const half2*>(tensor_in),
                         reinterpret_cast<const half2*>(tensor_add),
                         reinterpret_cast<half2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half2>), grid, block, 0, stream,
                         sequence_length, H,
                         reinterpret_cast<const half2*>(tensor_in),
                         reinterpret_cast<const half2*>(tensor_add),
                         reinterpret_cast<half2*>(tensor_out));
    }
  } else {
    // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half>), grid, block, 0, stream,
                         sequence_length, tensor_in, tensor_add, tensor_out);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half>), grid, block, 0, stream,
                         sequence_length, head_size, tensor_in, tensor_add, tensor_out);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const float* past, const float* k_v, float* present) {
  return LaunchConcatTensorToTensor(
      stream,
      all_sequence_length,
      sequence_length,
      batch_size,
      head_size,
      num_heads,
      max_threads_per_block,
      2,
      past,
      k_v,
      present);
}

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const half* past, const half* k_v, half* present) {
  return LaunchConcatTensorToTensor(
      stream,
      all_sequence_length,
      sequence_length,
      batch_size,
      head_size,
      num_heads,
      max_threads_per_block,
      2,
      past,
      k_v,
      present);
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
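For clarity, the concat performed above joins two chunked tensors along the sequence axis: K chunks of shape BxNxPxH (past) and K chunks of BxNxLxH (new keys/values) become K chunks of BxNxTxH with T = P + L. The following host-side reference sketch (not part of the commit; the function name and flat std::vector layout are assumptions for illustration) mirrors the same index math and can be used to verify the kernel:

// Host-side reference of ConcatTensorToTensor's layout, [K][B][N][T][H] output — sketch only.
#include <cstddef>
#include <vector>

std::vector<float> ConcatPastToPresentRef(const std::vector<float>& past,   // K x B x N x P x H
                                          const std::vector<float>& added,  // K x B x N x L x H
                                          int K, int B, int N, int P, int L, int H) {
  const int T = P + L;
  std::vector<float> present(static_cast<size_t>(K) * B * N * T * H);
  for (int k = 0; k < K; ++k)
    for (int b = 0; b < B; ++b)
      for (int n = 0; n < N; ++n)
        for (int t = 0; t < T; ++t)
          for (int h = 0; h < H; ++h) {
            const size_t out = ((((static_cast<size_t>(k) * B + b) * N + n) * T + t) * H + h);
            if (t < P) {
              // First P positions come from the past cache.
              const size_t in = ((((static_cast<size_t>(k) * B + b) * N + n) * P + t) * H + h);
              present[out] = past[in];
            } else {
              // Remaining L positions come from the newly added keys/values.
              const size_t in = ((((static_cast<size_t>(k) * B + b) * N + n) * L + (t - P)) * H + h);
              present[out] = added[in];
            }
          }
  return present;
}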
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_impl.h  0 → 100644
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {

size_t GetAttentionScratchSize(
    size_t element_size,
    size_t batch_size,
    size_t num_heads,
    size_t sequence_length,
    size_t all_sequence_length);

size_t GetAttentionWorkspaceSize(
    size_t element_size,
    size_t batchsize,
    size_t num_heads,
    size_t qk_head_size,
    size_t v_head_size,
    size_t sequence_length,
    size_t kv_sequence_length,
    size_t total_sequence_length,
    void* fused_runner);

template <typename T>
struct AttentionData {
  const T* gemm_buffer;
  const T* bias;

  const T* query;
  const T* key;
  const T* value;
  const int* mask_index;
  gsl::span<const int64_t> mask_index_dims;
  const T* past;
  const T* extra_add_qk;

  T* workspace;
  T* output;
  T* present;
};

template <typename T>
Status QkvToContext(
    const hipDeviceProp_t& prop,
    rocblas_handle& rocblas,
    hipStream_t stream,
    contrib::AttentionParameters& parameters,
    AttentionData<T>& data,
    void* fused_runner);

Status LaunchDecoderAttentionKernel(
    const hipDeviceProp_t& prop,      // Device Properties
    hipStream_t stream,               // Cuda stream
    rocblas_handle& rocblas,          // Rocblas handle
    const size_t element_size,        // Element size of input tensor
    const int batch_size,             // Batch size (B)
    const int sequence_length,        // Sequence length (S)
    const int kv_sequence_length,     // Key/Value/Cache sequence length
    const int num_heads,              // Number of attention heads (N)
    const int head_size,              // Hidden size per head (H)
    const bool static_kv,             // Whether cross attention or not
    const bool use_past,              // Whether use cache or not
    const bool has_layer_state,       // Whether output cache or not
    const bool has_key_padding_mask,  // Whether use key_padding_mask or not
    const void* gemm_query_buffer,    // Query buffer
    const void* gemm_kv_buffer,       // Key and value buffer
    const bool* key_padding_mask,     // Key padding mask
    const void* key_cache,            // Input key cache
    const void* value_cache,          // Input value cache
    void* qkv_buffer,                 // Temporary buffer
    void* workspace_buffer,           // Temporary buffer
    void* output,                     // Output tensor
    void* new_key_cache,              // New_key_cache tensor
    void* new_value_cache             // New_value_cache tensor
);

// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(hipStream_t stream,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);

Status LaunchTransCtx(hipStream_t stream,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);

// BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);

Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const float* tensor_in, const float* tensor_add, float* tensor_out);

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const half* tensor_in, const half* tensor_add, half* tensor_out);

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const float* past, const float* k_v, float* present);

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const half* past, const half* k_v, half* present);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_transpose.cu  0 → 100644
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Modifications: add transpose kernels for TRT format
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using
namespace
onnxruntime
::
rocm
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
template
<
typename
T
>
__global__
void
TransposeCtx
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Input: BxNxSxH
// Output: BxSxNxH
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
num_heads
=
blockDim
.
y
;
int
sequence_length
=
gridDim
.
x
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
in_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
;
int
out_offset
=
0
;
if
(
reversed_bs
)
{
const
int
batch_size
=
gridDim
.
y
;
const
int
BNH
=
NH
*
batch_size
;
out_offset
=
n
*
H
+
b
*
NH
+
s
*
BNH
;
}
else
{
out_offset
=
n
*
H
+
s
*
NH
+
b
*
NHS
;
}
const
int
i
=
threadIdx
.
x
;
if
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
}
}
template
<
typename
T
>
__global__
void
TransposeCtxLarge
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Use when (H*)*num_heads > 1024
// Input: BxNxSxH
// Output: BxSxNxH
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
stride
=
blockDim
.
x
;
int
num_heads
=
blockDim
.
y
;
int
sequence_length
=
gridDim
.
x
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
in_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
;
int
out_offset
=
0
;
if
(
reversed_bs
)
{
const
int
batch_size
=
gridDim
.
y
;
const
int
BNH
=
NH
*
batch_size
;
out_offset
=
n
*
H
+
b
*
NH
+
s
*
BNH
;
}
else
{
out_offset
=
n
*
H
+
s
*
NH
+
b
*
NHS
;
}
int
i
=
threadIdx
.
x
;
while
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
i
+=
stride
;
}
}
Status
LaunchTransCtx
(
hipStream_t
stream
,
const
int
sequence_length
,
const
int
batch_size
,
const
int
head_size
,
const
int
num_heads
,
const
int
max_threads_per_block
,
const
bool
reversed_bs
,
const
float
*
input
,
float
*
output
)
{
const
dim3
grid
(
sequence_length
,
batch_size
,
1
);
if
(
0
==
(
head_size
&
1
))
{
const
int
H
=
head_size
/
2
;
const
float2
*
input2
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
output2
=
reinterpret_cast
<
float2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
{
if
(
head_size
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
head_size
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
}
return
HIP_CALL
(
hipGetLastError
());
}
Status
LaunchTransCtx
(
hipStream_t
stream
,
const
int
sequence_length
,
const
int
batch_size
,
const
int
head_size
,
const
int
num_heads
,
const
int
max_threads_per_block
,
const
bool
reversed_bs
,
const
half
*
input
,
half
*
output
)
{
const
dim3
grid
(
sequence_length
,
batch_size
,
1
);
if
(
0
==
(
head_size
%
4
))
{
const
int
H
=
head_size
/
4
;
const
float2
*
input2
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
output2
=
reinterpret_cast
<
float2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
if
(
0
==
(
head_size
&
1
))
{
const
int
H
=
head_size
/
2
;
const
half2
*
input2
=
reinterpret_cast
<
const
half2
*>
(
input
);
half2
*
output2
=
reinterpret_cast
<
half2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
half2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
half2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
{
// this should be an "odd" case. probably not worth catching it in the half2 kernel.
if
(
head_size
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
head_size
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
half
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
half
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
}
return
HIP_CALL
(
hipGetLastError
());
}
template
<
typename
T
>
__global__
void
TransposeQKV
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
m
=
blockIdx
.
z
;
// matrix id
const
int
num_heads
=
blockDim
.
y
;
const
int
sequence_length
=
gridDim
.
x
;
const
int
batch_size
=
gridDim
.
y
;
const
int
chunk_num
=
gridDim
.
z
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
int
in_offset
=
0
;
if
(
reversed_bs
)
{
const
int
BNH
=
NH
*
batch_size
;
in_offset
=
n
*
H
+
(
m
+
b
*
chunk_num
)
*
NH
+
s
*
BNH
*
chunk_num
;
}
else
{
in_offset
=
n
*
H
+
(
m
+
s
*
chunk_num
)
*
NH
+
b
*
NHS
*
chunk_num
;
}
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
const
int
i
=
threadIdx
.
x
;
if
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
}
}
template
<
typename
T
>
__global__
void
TransposeQKVLarge
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Use when (H*)*num_heads > 1024
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
m
=
blockIdx
.
z
;
// matrix id
const
int
stride
=
blockDim
.
x
;
const
int
num_heads
=
blockDim
.
y
;
const
int
sequence_length
=
gridDim
.
x
;
const
int
batch_size
=
gridDim
.
y
;
const
int
chunk_num
=
gridDim
.
z
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
int
in_offset
=
0
;
if
(
reversed_bs
)
{
const
int
BNH
=
NH
*
batch_size
;
in_offset
=
n
*
H
+
(
m
+
b
*
chunk_num
)
*
NH
+
s
*
BNH
*
chunk_num
;
}
else
{
in_offset
=
n
*
H
+
(
m
+
s
*
chunk_num
)
*
NH
+
b
*
NHS
*
chunk_num
;
  }
  const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  int i = threadIdx.x;
  while (i < H) {
    output[out_offset + i] = input[in_offset + i];
    i += stride;
  }
}

Status LaunchTransQkv(hipStream_t stream, const int matrix_num, const int sequence_length,
                      const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs,
                      const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, matrix_num);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else {
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchTransQkv(hipStream_t stream, const int matrix_num, const int sequence_length,
                      const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs,
                      const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, matrix_num);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else {
    // this should be an "odd" case. probably not worth catching it in the half2 kernel..
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    }
  }
  return HIP_CALL(hipGetLastError());
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
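The two LaunchTransQkv overloads above only differ in how they pick a vector width before dispatching to TransposeQKV or TransposeQKVLarge. The following host-side sketch of a call site is illustrative only: the tensor sizes and the function name are assumptions, not values taken from this repository. It assumes a packed QKV input of shape (B, S, 3, N, H) being rearranged to (3, B, N, S, H).

// Illustrative sketch, not part of the sources above; sizes are assumptions.
Status ExampleTransQkv(hipStream_t stream, const hipDeviceProp_t& prop,
                       const float* packed_qkv,  // (B, S, 3, N, H)
                       float* transposed_qkv) {  // (3, B, N, S, H)
  const int matrix_num = 3;         // Q, K and V packed together
  const int batch_size = 1;         // B (assumed)
  const int sequence_length = 128;  // S (assumed)
  const int num_heads = 12;         // N (assumed)
  const int head_size = 64;         // H (assumed)
  // reversed_bs = false: input layout is (B, S, ...) rather than (S, B, ...).
  return LaunchTransQkv(stream, matrix_num, sequence_length, batch_size, head_size, num_heads,
                        prop.maxThreadsPerBlock, /*reversed_bs=*/false, packed_qkv, transposed_qkv);
}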
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.cu
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TrtSequenceOffset kernels are modified from FasterTransformer
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
using
namespace
onnxruntime
::
rocm
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
constexpr
int32_t
kMAX_THREADS_PER_BLOCK
=
256
;
// -----------------------------------
// Get indices of non-padding tokens and padding tokens. Here we assume that padding is on the right side of sequence.
// sequence_token_count is number of non-padding tokens per sequence, and it has shape [batch_size].
// For example, we have 3 sequences with 1, 2, 4 non-padding tokens and positions like the following (* means padding):
// Sequence_0: 0, 1*, 2*, 3*
// Sequence_1: 4, 5, 6*, 7*
// Sequence_2: 8, 9, 10, 11
// token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7*
// token_count_buffer has two numbers for non-padding tokens:
// total_token_count: 1 + 2 + 4 = 7
// max_token_count: 4
// cumulated_token_count: 0, 1, 1+2, 1+2+4
__global__
void
getTokenOffset
(
int
*
token_count_buffer
,
int
*
token_offset
,
int
*
cumulated_token_count
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
)
{
// Find offset of non-padding tokens, and max sequence length among all batches
// TODO(tianleiwu): Use hipcub::DevicePartition::Flagged like BuildGlobalIndex in longformer_global_impl.cu
// to build token_offset when sequence length is large.
int
total_tokens
=
0
;
int
max_tokens
=
0
;
int
index
=
0
;
cumulated_token_count
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
int
count
=
sequence_token_count
[
i
];
if
(
count
>
max_tokens
)
{
max_tokens
=
count
;
}
cumulated_token_count
[
i
+
1
]
=
cumulated_token_count
[
i
]
+
count
;
for
(
int
j
=
0
;
j
<
count
;
j
++
)
{
token_offset
[
index
]
=
i
*
sequence_length
+
j
;
index
++
;
}
total_tokens
+=
count
;
}
// Offset of paddings
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
int
count
=
sequence_token_count
[
i
];
for
(
int
j
=
0
;
j
<
sequence_length
-
count
;
j
++
)
{
token_offset
[
index
]
=
i
*
sequence_length
+
count
+
j
;
index
++
;
}
}
token_count_buffer
[
0
]
=
total_tokens
;
token_count_buffer
[
1
]
=
max_tokens
;
}
void
LaunchGetTokenOffset
(
int
*
token_count_buffer
,
int
*
token_offset
,
int
*
cumulated_token_count
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
,
hipStream_t
stream
)
{
hipLaunchKernelGGL
(
getTokenOffset
,
1
,
1
,
0
,
stream
,
token_count_buffer
,
token_offset
,
cumulated_token_count
,
sequence_token_count
,
batch_size
,
sequence_length
);
}
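// -----------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original file): a host-side reference of the same bookkeeping
// the kernel above performs. For sequence_token_count = {1, 2, 4} and sequence_length = 4 it
// reproduces the example in the comment above: token_offset = {0, 4, 5, 8, 9, 10, 11, 1, 2, 3, 6, 7}.
inline void TokenOffsetReference(int* token_offset, const int* sequence_token_count,
                                 int batch_size, int sequence_length) {
  int index = 0;
  for (int b = 0; b < batch_size; ++b) {  // non-padding tokens first, in batch order
    for (int j = 0; j < sequence_token_count[b]; ++j) {
      token_offset[index++] = b * sequence_length + j;
    }
  }
  for (int b = 0; b < batch_size; ++b) {  // then the padding positions, also in batch order
    for (int j = sequence_token_count[b]; j < sequence_length; ++j) {
      token_offset[index++] = b * sequence_length + j;
    }
  }
}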
// -----------------------------------
// Remove paddings
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    removePadding(T* target, const T* source, const int* token_offset, const int width) {
  const int tid = threadIdx.x;
  const int token_index = blockIdx.x;
  const int source_offset = token_offset[token_index];
  const int target_offset = token_index;

  for (int i = tid; i < width; i += blockDim.x) {
    target[target_offset * width + i] = source[source_offset * width + i];
  }
}

template <>
void LaunchRemovePadding(
    half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
    hipStream_t stream) {
  // input: [batch_size, sequence_length, hidden_size]
  // output: [token_count, hidden_size]

  // Make sure memory is aligned to 128 bit
  ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");

  if (hidden_size % 8 == 0) {
    const int width = hidden_size / 8;
    const int4* input2 = reinterpret_cast<const int4*>(input);
    int4* output2 = reinterpret_cast<int4*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  } else if (hidden_size % 4 == 0) {
    const int width = hidden_size / 4;
    const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
    int64_t* output2 = reinterpret_cast<int64_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  } else if (hidden_size % 2 == 0) {
    const int width = hidden_size / 2;
    const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
    int32_t* output2 = reinterpret_cast<int32_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  } else {
    const int width = hidden_size;
    const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
    int16_t* output2 = reinterpret_cast<int16_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int16_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  }
}

template <>
void LaunchRemovePadding(
    float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
    hipStream_t stream) {
  ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");

  if (hidden_size % 4 == 0) {
    const int width = hidden_size / 4;
    const int4* input2 = reinterpret_cast<const int4*>(input);
    int4* output2 = reinterpret_cast<int4*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int4>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  } else if (hidden_size % 2 == 0) {
    const int width = hidden_size / 2;
    const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
    int64_t* output2 = reinterpret_cast<int64_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int64_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  } else {
    const int width = hidden_size;
    const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
    int32_t* output2 = reinterpret_cast<int32_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(removePadding<int32_t>), token_count, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width);
  }
}
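// -----------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original file): the launchers above reinterpret the half/float
// rows as wider integer types so each thread copies 16, 8, 4 or 2 bytes per iteration. This tiny
// helper mirrors that choice for a given element size in bytes and hidden_size; the name is an
// assumption made for exposition only.
inline int CopyBytesPerThread(int element_size_in_bytes, int hidden_size) {
  const int row_bytes = element_size_in_bytes * hidden_size;
  if (row_bytes % 16 == 0) return 16;  // int4
  if (row_bytes % 8 == 0) return 8;    // int64_t
  if (row_bytes % 4 == 0) return 4;    // int32_t
  return 2;                            // int16_t (the half fallback)
}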
// -----------------------------------
// Recover padding.
template <typename T>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    restorePadding(T* target, const T* source, const int* token_offset, const int width, const int token_count) {
  const int tid = threadIdx.x;
  const int token_index = blockIdx.x;
  const int target_seq_id = token_offset[token_index];
  const int source_seq_id = token_index;
  constexpr T padding_zero = 0;

  if (token_index < token_count) {
    for (int i = tid; i < width; i += blockDim.x) {
      target[target_seq_id * width + i] = source[source_seq_id * width + i];
    }
  } else {
    // It is padding: fill with zeros
    for (int i = tid; i < width; i += blockDim.x) {
      target[target_seq_id * width + i] = padding_zero;
    }
  }
}

template <>
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    restorePadding(int4* target, const int4* source, const int* token_offset, const int width, const int token_count) {
  const int tid = threadIdx.x;
  const int token_index = blockIdx.x;
  const int target_seq_id = token_offset[token_index];
  const int source_seq_id = token_index;
  int4 padding_zero{0, 0, 0, 0};

  if (token_index < token_count) {
    for (int i = tid; i < width; i += blockDim.x) {
      target[target_seq_id * width + i] = source[source_seq_id * width + i];
    }
  } else {
    // It is padding: fill with zeros
    for (int i = tid; i < width; i += blockDim.x) {
      target[target_seq_id * width + i] = padding_zero;
    }
  }
}

template <>
void LaunchRestorePadding(
    float* output, const float* input, const int* token_offset, const int token_count, const int hidden_size,
    const int batch_size, const int sequence_length,
    hipStream_t stream) {
  ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");

  int grid_size = batch_size * sequence_length;
  if (hidden_size % 4 == 0) {
    const int width = hidden_size / 4;
    const int4* input2 = reinterpret_cast<const int4*>(input);
    int4* output2 = reinterpret_cast<int4*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  } else if (hidden_size % 2 == 0) {
    const int width = hidden_size / 2;
    const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
    int64_t* output2 = reinterpret_cast<int64_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  } else {
    const int width = hidden_size;
    const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
    int32_t* output2 = reinterpret_cast<int32_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  }
}

template <>
void LaunchRestorePadding(
    half* output, const half* input, const int* token_offset, const int token_count, const int hidden_size,
    const int batch_size, const int sequence_length,
    hipStream_t stream) {
  // input: [token_count, hidden_size]
  // output: [batch_size, sequence_length, hidden_size]
  ORT_ENFORCE(!(reinterpret_cast<size_t>(input) & 0xF) && !(reinterpret_cast<size_t>(output) & 0xF), "alignment");

  int grid_size = batch_size * sequence_length;
  if (hidden_size % 8 == 0) {
    const int width = hidden_size / 8;
    const int4* input2 = reinterpret_cast<const int4*>(input);
    int4* output2 = reinterpret_cast<int4*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int4>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  } else if (hidden_size % 4 == 0) {
    const int width = hidden_size / 4;
    const int64_t* input2 = reinterpret_cast<const int64_t*>(input);
    int64_t* output2 = reinterpret_cast<int64_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int64_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  } else if (hidden_size % 2 == 0) {
    const int width = hidden_size / 2;
    const int32_t* input2 = reinterpret_cast<const int32_t*>(input);
    int32_t* output2 = reinterpret_cast<int32_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int32_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  } else {
    const int width = hidden_size;
    const int16_t* input2 = reinterpret_cast<const int16_t*>(input);
    int16_t* output2 = reinterpret_cast<int16_t*>(output);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(restorePadding<int16_t>), grid_size, kMAX_THREADS_PER_BLOCK, 0, stream,
                       output2, input2, token_offset, width, token_count);
  }
}
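// -----------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original file; the function and parameter names are
// assumptions): the typical round trip these launchers enable. The batch is compacted to
// [token_count, hidden_size] before a dense kernel and scattered back to
// [batch_size, sequence_length, hidden_size] afterwards, with zeros written into padded positions.
inline void ExampleRemoveRestore(hipStream_t stream,
                                 const float* padded_in,   // [batch_size, sequence_length, hidden_size]
                                 float* compact,           // [token_count, hidden_size]
                                 float* padded_out,        // [batch_size, sequence_length, hidden_size]
                                 const int* token_offset,  // built by LaunchGetTokenOffset
                                 int token_count, int batch_size, int sequence_length, int hidden_size) {
  LaunchRemovePadding(compact, padded_in, token_offset, token_count, hidden_size, stream);
  // ... run attention or other kernels on the compact buffer here ...
  LaunchRestorePadding(padded_out, compact, token_offset, token_count, hidden_size,
                       batch_size, sequence_length, stream);
}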
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    getTrtSequenceOffset(int* trt_mha_padding_offset,
                         const int* sequence_token_count,
                         const int batch_size) {
  extern __shared__ int tmp_offset[];
  if (threadIdx.x == 0) {
    tmp_offset[0] = 0;
    for (int i = 0; i < batch_size; i++) {
      tmp_offset[i + 1] = tmp_offset[i] + sequence_token_count[i];
    }
  }
  __syncthreads();

  for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
    trt_mha_padding_offset[i] = tmp_offset[i];
  }
}

// Get sequence offset for TensorRT fused attention when there is no padding (or padding is removed)
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* sequence_token_count,
                             const int batch_size,
                             hipStream_t stream) {
  hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (batch_size + 1), stream,
                     trt_mha_padding_offset, sequence_token_count, batch_size);
}

__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
    getTrtSequenceOffset(int* trt_mha_padding_offset,
                         const int* sequence_token_count,
                         const int batch_size,
                         const int sequence_length) {
  extern __shared__ int tmp_offset[];
  if (threadIdx.x == 0) {
    tmp_offset[0] = 0;
    // B for fused attention is 2 * batch_size
    for (int i = 0; i < batch_size; i++) {
      tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_token_count[i];
      tmp_offset[i * 2 + 2] = sequence_length * (i + 1);
    }
  }
  __syncthreads();

  for (int i = threadIdx.x; i < 2 * batch_size + 1; i += blockDim.x) {
    trt_mha_padding_offset[i] = tmp_offset[i];
  }
}

// Get sequence offset for TensorRT fused attention when we keep the padding
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* sequence_token_count,
                             const int batch_size,
                             const int sequence_length,
                             hipStream_t stream) {
  hipLaunchKernelGGL(getTrtSequenceOffset, 1, kMAX_THREADS_PER_BLOCK, sizeof(int) * (2 * batch_size + 1), stream,
                     trt_mha_padding_offset, sequence_token_count, batch_size, sequence_length);
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
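For concreteness, here is a small worked example of the two offset layouts produced above. The numbers are illustrative, not taken from any test in the repository: with batch_size = 2, sequence_length = 4 and per-sequence token counts {3, 2}, the kernels yield the following arrays.

// Worked example (illustration only):
//   no-padding variant:   batch_size + 1 cumulative token counts
//   keep-padding variant: 2 * batch_size + 1 entries, interleaving token count and sequence end
static const int kExampleOffsetsNoPadding[]   = {0, 3, 5};
static const int kExampleOffsetsKeepPadding[] = {0, 3, 4, 6, 8};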
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>

namespace onnxruntime {
namespace contrib {
namespace rocm {

// Build token indices for non-padding tokens and padding tokens.
void LaunchGetTokenOffset(int* token_count_buffer,
                          int* token_offset,
                          int* cumulated_token_count,
                          const int* sequence_token_count,
                          const int batch_size,
                          const int sequence_length,
                          hipStream_t stream);

// Remove paddings from input.
template <typename T>
void LaunchRemovePadding(
    T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
    hipStream_t stream);

// Rebuild paddings to restore output shape.
template <typename T>
void LaunchRestorePadding(
    T* output, const T* input, const int* token_offset, const int token_count, const int hidden_size,
    const int batch_size, const int sequence_length,
    hipStream_t stream);

// Padding offset for TensorRT fused attention kernel
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* mask_index,
                             const int batch_size,
                             hipStream_t stream);

void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* mask_index,
                             const int batch_size,
                             const int sequence_length,
                             hipStream_t stream);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/decoder_attention.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"

using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      DecoderAttention,                                           \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kRocmExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      DecoderAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
namespace {

Status CheckInputs(const TensorShape& query_shape,
                   const TensorShape& key_shape,
                   const TensorShape& q_weights_shape,
                   const TensorShape& kv_weights_shape,
                   const TensorShape& bias_shape,
                   const Tensor* key_padding_mask,
                   const Tensor* key_cache,
                   const Tensor* value_cache,
                   const bool static_kv,
                   const bool use_past,
                   const bool has_layer_state,
                   const bool has_key_padding_mask) {
  const auto& query_shape_dims = query_shape.GetDims();
  if (query_shape_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
                           query_shape_dims.size());
  }
  int sequence_length = static_cast<int>(query_shape_dims[0]);
  int batch_size = static_cast<int>(query_shape_dims[1]);
  int hidden_size = static_cast<int>(query_shape_dims[2]);

  const auto& key_shape_dims = key_shape.GetDims();
  if (key_shape_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
                           key_shape_dims.size());
  }
  int kv_sequence_length = static_cast<int>(key_shape_dims[0]);
  if (query_shape_dims[1] != key_shape_dims[1]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same batch size");
  }
  if (query_shape_dims[2] != key_shape_dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same hidden size");
  }

  const auto& q_weights_dims = q_weights_shape.GetDims();
  if (q_weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'q_weights' is expected to have 2 dimensions, got ",
                           q_weights_dims.size());
  }
  const auto& kv_weights_dims = kv_weights_shape.GetDims();
  if (kv_weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'kv_weights' is expected to have 2 dimensions, got ",
                           kv_weights_dims.size());
  }
  if (q_weights_dims[0] != hidden_size || q_weights_dims[1] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "q_weights shall have shape (hidden size, hidden size)");
  }
  if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
  }

  const auto& bias_dims = bias_shape.GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'bias' is expected to have 1 dimension, got ",
                           bias_dims.size());
  }
  if (bias_dims[0] != 3 * static_cast<int64_t>(hidden_size)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias shall have shape (3 * hidden size)");
  }

  int key_length = kv_sequence_length;
  if (key_padding_mask != nullptr && has_key_padding_mask == true) {
    const auto& kp_mask_dims = key_padding_mask->Shape().GetDims();
    if (kp_mask_dims.size() != 2) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'key_padding_mask' is expected to have 2 dimension, got ",
                             kp_mask_dims.size());
    }
    if (kp_mask_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_padding_mask shall have same batch size with query");
    }
    if (!has_layer_state || !use_past) {
      if (!static_kv) {
        key_length = sequence_length;
      }
    } else {
      if (!static_kv) {
        key_length = sequence_length + kv_sequence_length;
      }
    }
    if (kp_mask_dims[1] != key_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "key_padding_mask shall have same sequence length as generated key");
    }
  }

  if (key_cache != nullptr && value_cache != nullptr && has_layer_state && use_past) {
    const auto& key_cache_dims = key_cache->Shape().GetDims();
    if (key_cache_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key_cache' is expected to have 4 dimension, got ",
                             key_cache_dims.size());
    }
    const auto& value_cache_dims = value_cache->Shape().GetDims();
    if (value_cache_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ",
                             value_cache_dims.size());
    }
    if (key_cache_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have same batch size as query");
    }
    if (value_cache_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have same batch size as query");
    }
    if (key_cache_dims[1] * key_cache_dims[3] != hidden_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have correct hidden size");
    }
    if (value_cache_dims[1] * value_cache_dims[3] != hidden_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have correct hidden size");
    }
  }

  return Status::OK();
}

}  // anonymous namespace
template <typename T>
DecoderAttention<T>::DecoderAttention(const OpKernelInfo& info) : RocmKernel(info) {
  int64_t num_heads = 0;
  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
  num_heads_ = static_cast<int>(num_heads);
}
template <typename T>
Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* query(context->Input<Tensor>(0));
  const Tensor* key(context->Input<Tensor>(1));
  const Tensor* q_weights(context->Input<Tensor>(2));
  const Tensor* kv_weights(context->Input<Tensor>(3));
  const Tensor* bias(context->Input<Tensor>(4));
  const Tensor* key_padding_mask(context->Input<Tensor>(5));
  const Tensor* key_cache(context->Input<Tensor>(6));
  const Tensor* value_cache(context->Input<Tensor>(7));
  const Tensor* static_kv(context->Input<Tensor>(8));
  const Tensor* use_past(context->Input<Tensor>(9));
  const Tensor* has_layer_state(context->Input<Tensor>(10));
  const Tensor* has_key_padding_mask(context->Input<Tensor>(11));

  hipStream_t stream = Stream();

  // Copy static_kv, use_past and has_layer_state to CPU
  auto pinned_buffer = AllocateBufferOnCPUPinned<void>(4 * sizeof(bool));
  bool* kernel_state_pinned = reinterpret_cast<bool*>(pinned_buffer.get());
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned, static_kv->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 1, use_past->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));

  // Create an event to make sure the async copy is finished before reading the data.
  AutoDestoryCudaEvent new_event;
  hipEvent_t& isCopyDone = new_event.Get();
  HIP_RETURN_IF_ERROR(hipEventCreate(&isCopyDone));
  HIP_RETURN_IF_ERROR(hipEventRecord(isCopyDone, stream));

  auto& device_prop = GetDeviceProp();

  // query shape (batch_size, sequence_length, input_hidden_size)
  const auto& query_shape = query->Shape();
  int sequence_length = static_cast<int>(query_shape[0]);
  int batch_size = static_cast<int>(query_shape[1]);
  int hidden_size = static_cast<int>(query_shape[2]);

  const auto& key_shape = key->Shape();
  int key_sequence_length = static_cast<int>(key_shape[0]);

  int head_size = hidden_size / num_heads_;

  // k, v sequence length after gemm
  int kv_sequence_length = 0;

  // Generate q, k, v w/o cache
  // query input: (S, B, h1)
  // key input:   (S', B, h1)
  // weight:      (h1, h2)
  // h = N*H
  rocblas_handle rocblas = RocblasHandle();
  ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));

  constexpr size_t element_size = sizeof(T);

  typedef typename ToHipType<T>::MappedType HipT;
  HipT one = ToHipType<T>::FromFloat(1.0f);
  HipT zero = ToHipType<T>::FromFloat(0.0f);

  int m = 0, n = 0, k = 0;
  IAllocatorUniquePtr<T> gemm_query_buffer_p(nullptr);
  IAllocatorUniquePtr<T> gemm_kv_buffer_p(nullptr);

  HIP_RETURN_IF_ERROR(hipEventSynchronize(isCopyDone));
  bool static_kv_ = *kernel_state_pinned;
  bool use_past_ = *(kernel_state_pinned + 1);
  bool has_layer_state_ = *(kernel_state_pinned + 2);
  bool has_key_padding_mask_ = *(kernel_state_pinned + 3);

  ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key->Shape(), q_weights->Shape(), kv_weights->Shape(),
                                  bias->Shape(), key_padding_mask, key_cache, value_cache,
                                  static_kv_, use_past_, has_layer_state_, has_key_padding_mask_));

  // calculate q
  gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size);
  m = sequence_length * batch_size;
  n = hidden_size;
  k = hidden_size;

  // TODO(tianleiwu): fuse bias and transpose
  // broadcast bias for query: (h2, S*B)
  ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
      rocblas, rocblas_operation_none, rocblas_operation_none,
      n, m, 1, &one,
      reinterpret_cast<const HipT*>(bias->Data<T>()), n,
      GetConstOnes<HipT>(m), 1,
      &zero, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
  // matmul: (h2, h1)*(h1, S*B)
  ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
      rocblas, rocblas_operation_none, rocblas_operation_none,
      n, m, k, &one,
      reinterpret_cast<const HipT*>(q_weights->Data<T>()), n,
      reinterpret_cast<const HipT*>(query->Data<T>()), k,
      &one, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
  // gemm_query_buffer in col-base: (h2, S*B)

  // calculate k, v
  n = 2 * hidden_size;
  k = hidden_size;
  if (!has_layer_state_ || !use_past_) {
    if (!static_kv_) {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
      m = sequence_length * batch_size;
      n = 2 * hidden_size;
      k = hidden_size;
      kv_sequence_length = sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(query->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    } else {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size);
      m = key_sequence_length * batch_size;
      n = 2 * hidden_size;
      k = hidden_size;
      kv_sequence_length = key_sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(key->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    }
  } else {
    ORT_ENFORCE(nullptr != key_cache && nullptr != value_cache);
    // (B, N, S, H)
    const auto& cache_shape = key_cache->Shape();
    // key and value cache have identical shape
    int cache_sequence_length = static_cast<int>(cache_shape[2]);
    if (!static_kv_) {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
      m = sequence_length * batch_size;
      kv_sequence_length = cache_sequence_length + sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(query->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    } else {
      kv_sequence_length = cache_sequence_length;
    }
  }

  size_t bytes = element_size * batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size;
  auto qkv_buffer_p = GetScratchBuffer<void>(bytes);

  bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (2 * head_size + kv_sequence_length);
  auto workspace_p = GetScratchBuffer<void>(bytes);

  Tensor* output(context->Output(0, query_shape));
  TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size});
  Tensor* new_key_cache(context->Output(1, new_cache_shape));
  Tensor* new_value_cache(context->Output(2, new_cache_shape));

  return LaunchDecoderAttentionKernel(
      device_prop,
      stream,
      rocblas,
      element_size,
      batch_size,
      sequence_length,
      kv_sequence_length,
      num_heads_,
      head_size,
      static_kv_,
      use_past_,
      has_layer_state_,
      has_key_padding_mask_,
      nullptr == gemm_query_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_query_buffer_p.get()),
      nullptr == gemm_kv_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_kv_buffer_p.get()),
      nullptr == key_padding_mask ? nullptr : key_padding_mask->Data<bool>(),
      nullptr == key_cache ? nullptr : key_cache->Data<T>(),
      nullptr == value_cache ? nullptr : value_cache->Data<T>(),
      qkv_buffer_p.get(),
      workspace_p.get(),
      output->MutableData<T>(),
      nullptr == new_key_cache ? nullptr : new_key_cache->MutableData<T>(),
      nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>());
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
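The q/k/v projections in ComputeInternal add the bias without a separate kernel: a rank-1 GEMM with k = 1 multiplies the bias column by a vector of ones, which fills every column of the output with the bias, and the following GEMM accumulates the weight matmul on top (beta = 1). The CPU sketch below is an illustration of that arithmetic only; the function name and row-major layout are assumptions, not code from this repository.

// CPU reference of the bias-broadcast trick (illustration only).
#include <vector>

void BroadcastBiasThenGemm(const std::vector<float>& bias,  // n
                           const std::vector<float>& W,     // n x k, row-major
                           const std::vector<float>& X,     // k x m, row-major
                           std::vector<float>& C,           // n x m, row-major
                           int n, int k, int m) {
  // Step 1: C = bias * ones^T -- the rank-1 product replicates the bias across all m columns.
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < m; ++j)
      C[i * m + j] = bias[i];
  // Step 2: C += W * X -- the "beta = 1" accumulation of the projection.
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < m; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += W[i * k + p] * X[p * m + j];
      C[i * m + j] += acc;
    }
}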
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

template <typename T>
class DecoderAttention final : public RocmKernel {
 public:
  DecoderAttention(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int num_heads_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include <hipcub/hipcub.hpp>
using
namespace
onnxruntime
::
rocm
;
using
namespace
hipcub
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
template
<
typename
T
>
__device__
inline
T
Rsqrt
(
const
T
&
x
);
template
<
>
__device__
inline
float
Rsqrt
(
const
float
&
x
)
{
return
rsqrtf
(
x
);
}
template
<
>
__device__
inline
half
Rsqrt
(
const
half
&
x
)
{
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return
hrsqrt
(
x
);
#else
return
half
(
rsqrtf
(
float
(
x
)));
#endif
}
__device__
inline
half2
AddHalf2
(
const
half2
a
,
const
half2
b
)
{
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return
__hadd2
(
a
,
b
);
#else
return
__halves2half2
(
__hadd
(
a
.
x
,
b
.
x
),
__hadd
(
a
.
y
,
b
.
y
));
#endif
}
struct
KeyValuePairSum
{
__device__
inline
hipcub
::
KeyValuePair
<
float
,
float
>
operator
()(
const
hipcub
::
KeyValuePair
<
float
,
float
>&
a
,
const
hipcub
::
KeyValuePair
<
float
,
float
>&
b
)
{
return
hipcub
::
KeyValuePair
<
float
,
float
>
(
a
.
key
+
b
.
key
,
a
.
value
+
b
.
value
);
}
__device__
inline
hipcub
::
KeyValuePair
<
half
,
half
>
operator
()(
const
hipcub
::
KeyValuePair
<
half
,
half
>&
a
,
const
hipcub
::
KeyValuePair
<
half
,
half
>&
b
)
{
const
half2
a2
=
__halves2half2
(
a
.
key
,
a
.
value
);
const
half2
b2
=
__halves2half2
(
b
.
key
,
b
.
value
);
const
half2
res
=
AddHalf2
(
a2
,
b2
);
return
hipcub
::
KeyValuePair
<
half
,
half
>
(
__low2half
(
res
),
__high2half
(
res
));
}
__device__
inline
hipcub
::
KeyValuePair
<
half2
,
half2
>
operator
()(
const
hipcub
::
KeyValuePair
<
half2
,
half2
>&
a
,
const
hipcub
::
KeyValuePair
<
half2
,
half2
>&
b
)
{
return
hipcub
::
KeyValuePair
<
half2
,
half2
>
(
AddHalf2
(
a
.
key
,
b
.
key
),
AddHalf2
(
a
.
value
,
b
.
value
));
}
};
template
<
typename
T
,
int
TPB
>
__device__
inline
void
LayerNorm
(
const
hipcub
::
KeyValuePair
<
T
,
T
>&
thread_data
,
const
int
ld
,
const
int
offset
,
const
T
*
beta
,
const
T
*
gamma
,
const
T
epsilon
,
T
*
output
)
{
// Assuming thread_data is already divided by ld
using
BlockReduce
=
hipcub
::
BlockReduce
<
hipcub
::
KeyValuePair
<
T
,
T
>
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
mu
;
// mean
__shared__
T
rsigma
;
// 1 / std.dev.
KeyValuePairSum
pair_sum
;
const
auto
sum_kv
=
BlockReduce
(
temp_storage
).
Reduce
(
thread_data
,
pair_sum
);
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
Rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
epsilon
);
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
const
int
idx
=
offset
+
i
;
const
T
val
=
output
[
idx
];
const
T
g
(
gamma
[
i
]);
const
T
b
=
(
nullptr
==
beta
)
?
(
T
)
0
:
beta
[
i
];
output
[
idx
]
=
g
*
(
val
-
mu
)
*
rsigma
+
b
;
}
}
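// -----------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original header): a scalar CPU reference of the statistics the
// block reduce above produces. Each thread contributes the pair (x / ld, x * x / ld); after reduction
// mu = E[x] and rsigma = 1 / sqrt(E[x^2] - mu^2 + epsilon), then y_i = gamma_i * (x_i - mu) * rsigma + beta_i.
// The function name is an assumption made for exposition.
#include <cmath>
inline void LayerNormReference(const float* x, const float* gamma, const float* beta,
                               int ld, float epsilon, float* y) {
  float sum = 0.0f, sum_sq = 0.0f;
  for (int i = 0; i < ld; ++i) {
    sum += x[i] / ld;            // running E[x]
    sum_sq += x[i] * x[i] / ld;  // running E[x^2]
  }
  const float mu = sum;
  const float rsigma = 1.0f / std::sqrt(sum_sq - mu * mu + epsilon);
  for (int i = 0; i < ld; ++i) {
    y[i] = gamma[i] * (x[i] - mu) * rsigma + (beta != nullptr ? beta[i] : 0.0f);
  }
}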
template <typename T, int TPB, int ILP>
__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<T, T>& thread_data,
                                      const int ld, const int idx, const T* beta, const T* gamma,
                                      const T epsilon, T* output) {
  // Assuming thread_data is already divided by ld
  // Small settings: the block covers the leading dimension TPB >= ld. The input
  // value is available in a register

  using VecT = aligned_vector<T, ILP>;
  using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T mu;      // mean
  __shared__ T rsigma;  // 1 / std.dev.

  T beta_v[ILP], gamma_v[ILP], output_v[ILP];

  if (beta != nullptr) {
    VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
    *beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
  }

  VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
  *gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);

  VecT* output_val = reinterpret_cast<VecT*>(&output_v);

  KeyValuePairSum pair_sum;
  const hipcub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);

  if (threadIdx.x == 0) {
    mu = sum_kv.key;
    rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
  }
  __syncthreads();

  if (ILP * threadIdx.x < ld) {
#pragma unroll
    for (int i = 0; i < ILP; i++) {
      output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
                                      : gamma_v[i] * (input_v[i] - mu) * rsigma;
    }
    *(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
  }
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/rocm/bert/longformer_global_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention.h"

using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      LongformerAttention,                                        \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kRocmExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      LongformerAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info)
    : RocmKernel(info), LongformerAttentionBase(info) {
  use_compact_memory_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseCompactMemory, true);
  use_half4_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseHalf4, true);
}
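// -----------------------------------------------------------------------------------------------
// Illustrative aside (not part of the original file): both flags read above come from environment
// variables, so a host process can flip them before this kernel is constructed. The variable name
// below matches the ORT_LONGFORMER_COMPACT_MEMORY=0 case described in the comments further down in
// this file; the helper itself is a sketch, not repository code.
#include <cstdlib>
inline void DisableCompactMemoryKernelForThisProcess() {
  setenv("ORT_LONGFORMER_COMPACT_MEMORY", "0", /*overwrite=*/1);  // POSIX setenv; call before session creation
}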
template <typename T>
Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* weights = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);
  const Tensor* attention_mask = context->Input<Tensor>(3);
  const Tensor* global_weights = context->Input<Tensor>(4);
  const Tensor* global_bias = context->Input<Tensor>(5);
  const Tensor* global_attention_mask = context->Input<Tensor>(6);
  ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), attention_mask->Shape(),
                                  global_weights->Shape(), global_bias->Shape(), global_attention_mask->Shape()));

  // Input shapes:
  //   input                 : (batch_size, sequence_length, hidden_size)
  //   weights               : (hidden_size, 3 * hidden_size) -- format 1
  //                           (3, hidden_size, hidden_size)  -- format 0
  //   bias                  : (3 * hidden_size)              -- format 1 (bias for Q, K, V)
  //                           (5 * hidden_size)              -- format 0 (bias for Q, K, V, Global_K, Global_V)
  //   attention_mask        : (batch_size, sequence_length)
  //   global_weights        : (hidden_size, 3 * hidden_size) -- format 1
  //                           (3, hidden_size, hidden_size)  -- format 0
  //   global_bias           : (3 * hidden_size)              -- format 1 (bias for Global_Q, Global_K, Global_V)
  //                           (1 * hidden_size)              -- format 0 (bias for Global_Q)
  //   global_attention_mask : (batch_size, sequence_length)
  // Output shapes:
  //   output                : (batch_size, sequence_length, hidden_size)
  const auto& shape = input->Shape();
  int batch_size = static_cast<int>(shape[0]);
  int sequence_length = static_cast<int>(shape[1]);
  int hidden_size = static_cast<int>(shape[2]);
  int head_size = hidden_size / num_heads_;

  Tensor* output = context->Output(0, shape);

  rocblas_handle rocblas = RocblasHandle();
  hipStream_t stream = Stream();
  ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));

  constexpr size_t element_size = sizeof(T);

  // TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node.
  // Build Global Index
  auto global_index_buffer = GetScratchBuffer<int>(static_cast<size_t>(batch_size) * sequence_length);
  auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);

  size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length);
  auto global_scratch_buffer = GetScratchBuffer<void>(global_scratch_bytes);

  auto& device_prop = GetDeviceProp();
  ORT_RETURN_IF_ERROR(BuildGlobalIndex(
      device_prop,
      stream,
      global_attention_mask->Data<int>(),
      batch_size,
      sequence_length,
      global_index_buffer.get(),
      batch_global_num_buffer.get(),
      global_scratch_buffer.get(),
      global_scratch_bytes));

  // Copy batch_global_num to CPU
  size_t pinned_buffer_bytes = GetPinnedBufferSize(batch_size);
  auto pinned_buffer = AllocateBufferOnCPUPinned<void>(pinned_buffer_bytes);
  int* batch_global_num_pinned = reinterpret_cast<int*>(pinned_buffer.get());
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(batch_global_num_pinned, batch_global_num_buffer.get(),
                                     batch_size * sizeof(int), hipMemcpyDeviceToHost, stream));

  // Create an event to make sure the async copy is finished before reading the data.
  AutoDestoryCudaEvent new_event;
  hipEvent_t& is_copy_done = new_event.Get();
  HIP_RETURN_IF_ERROR(hipEventCreateWithFlags(&is_copy_done, hipEventDisableTiming));
  HIP_RETURN_IF_ERROR(hipEventRecord(is_copy_done, stream));

  size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size;
  // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v
  // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size.
  auto gemm_buffer = GetScratchBuffer<void>(qkv_size + qkv_size);

  bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2);

  int m = batch_size * sequence_length;
  int n = use_merged_qkv_weights ? 3 * hidden_size : hidden_size;
  int k = hidden_size;
  typedef typename ToHipType<T>::MappedType HipT;
  const HipT* input_data = reinterpret_cast<const HipT*>(input->Data<T>());
  const HipT* weights_data = reinterpret_cast<const HipT*>(weights->Data<T>());
  const HipT* global_weights_data = reinterpret_cast<const HipT*>(global_weights->Data<T>());

  float one = 1.0f;
  float zero = 0.0f;

  if (use_merged_qkv_weights) {
    // Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 0 x B.
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        weights_data, n,
        input_data, k,
        &zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n, device_prop));
  } else {
    // q
    const HipT* q_weight = weights_data;
    HipT* q_data = reinterpret_cast<HipT*>(gemm_buffer.get());
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        q_weight, n,
        input_data, k,
        &zero, q_data, n, device_prop));
    // k
    const HipT* k_weight = q_weight + hidden_size * hidden_size;
    HipT* k_data = q_data + batch_size * sequence_length * hidden_size;
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        k_weight, n,
        input_data, k,
        &zero, k_data, n, device_prop));
    // v
    const HipT* v_weight = k_weight + hidden_size * hidden_size;
    HipT* v_data = k_data + batch_size * sequence_length * hidden_size;
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        v_weight, n,
        input_data, k,
        &zero, v_data, n, device_prop));
  }

  // Wait for async copy of batch_global_num
  HIP_RETURN_IF_ERROR(hipEventSynchronize(is_copy_done));

  // Find the maximum number of global tokens in all batches
  int max_num_global = 0;
  for (int i = 0; i < batch_size; ++i) {
    if (max_num_global < batch_global_num_pinned[i]) {
      max_num_global = batch_global_num_pinned[i];
    }
  }

  // Do not use compact memory kernel in the following situations:
  // (1) global tokens > windows size, compact memory kernel cannot be used due to its assumptions.
  // (2) sequence_length == 2 * attention_window, compact memory kernel has parity issue.
  // (3) user sets environment variable ORT_LONGFORMER_COMPACT_MEMORY=0
  bool disable_compact_memory = (max_num_global > window_ || sequence_length == 2 * window_ || !use_compact_memory_);

  // Fully connection for global projection.
  // Note that Q only need handle global query tokens if we split GEMM to global Q/K/V separately.
  // When there is no global token, need not run global GEMM.
  HipT* global_gemm_buffer = nullptr;
  if (max_num_global > 0) {
    global_gemm_buffer = reinterpret_cast<HipT*>(reinterpret_cast<char*>(gemm_buffer.get()) + qkv_size);

    if (use_merged_qkv_weights) {
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          reinterpret_cast<const HipT*>(global_weights->Data<T>()), n,
          input_data, k,
          &zero, global_gemm_buffer, n, device_prop));
    } else {
      // global q
      const HipT* global_q_weight = global_weights_data;
      HipT* global_q = global_gemm_buffer + 2 * batch_size * sequence_length * hidden_size;
      if (disable_compact_memory) {
        ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
            rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
            global_q_weight, n,
            input_data, k,
            &zero, global_q, n, device_prop));
      } else {
        ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(
            rocblas,
            rocblas_operation_none,
            rocblas_operation_none,
            hidden_size,                    // m
            max_num_global,                 // n
            hidden_size,                    // k
            &one,                           // alpha
            global_q_weight,                // A
            hidden_size,                    // lda
            0,                              // strideA
            input_data,                     // B
            hidden_size,                    // ldb
            sequence_length * hidden_size,  // strideB
            &zero,                          // beta
            global_q,                       // C
            hidden_size,                    // ldc
            max_num_global * hidden_size,   // strideC
            batch_size,                     // batch count
            device_prop));
      }
      // global k
      const HipT* global_k_weight = global_weights_data + hidden_size * hidden_size;
      HipT* global_k = global_gemm_buffer;
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          global_k_weight, n,
          input_data, k,
          &zero, global_k, n, device_prop));
      // global v
      const HipT* global_v_weight = global_k_weight + hidden_size * hidden_size;
      HipT* global_v = global_gemm_buffer + batch_size * sequence_length * hidden_size;
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          global_v_weight, n,
          input_data, k,
          &zero, global_v, n, device_prop));
    }
  }

  size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
                                                             sequence_length, max_num_global, window_,
                                                             disable_compact_memory);
  auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
  ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(
      device_prop,
      rocblas,
      stream,
      reinterpret_cast<const HipT*>(gemm_buffer.get()),
      reinterpret_cast<const HipT*>(bias->Data<T>()),
      reinterpret_cast<const HipT*>(attention_mask->Data<T>()),
      reinterpret_cast<const HipT*>(global_gemm_buffer),
      reinterpret_cast<const HipT*>(global_bias->Data<T>()),
      global_attention_mask->Data<int>(),
      global_index_buffer.get(),
      batch_global_num_buffer.get(),
      pinned_buffer.get(),
      workspace_buffer.get(),
      output->MutableData<T>(),
      batch_size,
      sequence_length,
      num_heads_,
      head_size,
      window_,
      max_num_global,
      element_size,
      disable_compact_memory,
      use_merged_qkv_weights,
      use_half4_));

  // Defer release of pinned memory since hipStreamSynchronize is not used here and kernel need access the buffer.
  this->AddDeferredReleaseCPUPtr(pinned_buffer.release());

  return Status::OK();
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

template <typename T>
class LongformerAttention final : public RocmKernel, public LongformerAttentionBase {
 public:
  LongformerAttention(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  bool use_compact_memory_;
  bool use_half4_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.cu
(Diff collapsed; file contents not shown.)
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/rocm/shared_inc/rocm_utils.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

size_t GetPinnedBufferSize(size_t batch_size);

size_t GetLongformerAttentionWorkspaceSize(
    size_t element_size,
    size_t batch_size,
    size_t num_heads,
    size_t head_size,
    size_t sequence_length,
    size_t max_num_global,
    size_t window,
    bool disable_compact_memory);

Status LaunchLongformerAttentionKernel(
    const hipDeviceProp_t& device_prop,  // Device Properties
    rocblas_handle rocblas,              // Rocblas handle
    hipStream_t stream,                  // ROCM stream
    const void* input,                   // Input tensor
    const void* bias,                    // Bias tensor
    const void* attention_mask,          // Attention mask with shape (B, S)
    const void* global_input,            // Global attention input, or nullptr when max_num_global == 0.
    const void* global_bias,             // Global bias tensor
    const int* global_attention,         // Global attention flags with shape (B, S)
    const int* global_index,             // Global index
    const int* batch_global_num,         // Number of global tokens per batch. It is in device memory.
    void* pinned_buffer,                 // Pinned memory: copy of batch_global_num, and a buffer to copy to scratch2.
    void* workspace,                     // Temporary buffer
    void* output,                        // Output tensor
    int batch_size,                      // Batch size (B)
    int sequence_length,                 // Sequence length (S)
    int num_heads,                       // Number of attention heads (N)
    int head_size,                       // Hidden layer size per head (H)
    int window,                          // One sided attention window (W)
    int max_num_global,                  // Maximum number of global tokens (G)
    const size_t element_size,           // Element size of input tensor
    bool disable_compact_memory,         // Disable compact memory kernel
    bool use_merged_qkv_weights,
    bool use_half4);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
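A caller sizes the two temporary buffers declared above before launching the kernel: the pinned buffer is keyed off the batch size, the workspace off the full problem shape and the compact-memory decision. The sketch below only calls the two sizing functions from this header; the function name and all numeric sizes are assumptions for illustration.

// Illustrative sketch, not repository code; sizes are assumed values.
#include <cstddef>

size_t ExampleLongformerScratchBytes() {
  const size_t element_size = sizeof(float);  // assumed fp32
  const size_t batch_size = 2, num_heads = 12, head_size = 64;
  const size_t sequence_length = 4096, max_num_global = 16, window = 512;  // assumed shape
  const bool disable_compact_memory = false;

  const size_t pinned_bytes = GetPinnedBufferSize(batch_size);
  const size_t workspace_bytes = GetLongformerAttentionWorkspaceSize(
      element_size, batch_size, num_heads, head_size, sequence_length,
      max_num_global, window, disable_compact_memory);
  return pinned_bytes + workspace_bytes;  // in practice each buffer is allocated separately
}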
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_softmax.cu
(Diff collapsed; file contents not shown.)