gaoqiong / onnxruntime_v14 · Commits

Commit 1a91fcc2, authored Jul 25, 2023 by gaoqiong

    add files required by dtk

Parent: a144865d
Pipeline #492 failed with stages in 0 seconds
Changes: 280 · Pipelines: 1

Showing 20 changed files with 4579 additions and 0 deletions (+4579, -0)
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.cc             +61   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.h              +96   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.cu        +90   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.h         +30   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/aten_ops/aten_op.cc                   +18   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.cu            +459  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.h             +43   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_concat.cu              +250  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_impl.h                 +150  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_transpose.cu           +304  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.cu                  +343  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.h                   +48   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.cc             +392  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.h              +26   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh                   +158  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.cc          +295  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.h           +29   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.cu     +1059 -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.h      +53   -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_softmax.cu  +675  -0
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.cc (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "activations.h"
#include "core/framework/op_kernel.h"

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T)            \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                 \
      x,                                                         \
      domain,                                                    \
      ver,                                                       \
      T,                                                         \
      kRocmExecutionProvider,                                    \
      (*KernelDefBuilder::Create())                              \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .MayInplace(0, 0),                                     \
      x<T>);

#define UNARY_ACTIVATION_COMPUTE(x, T)                                                           \
  template <>                                                                                    \
  Status x<T>::ComputeInternal(OpKernelContext* context) const {                                 \
    UnaryElementwisePreparation p;                                                               \
    ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));                                 \
    Ctx##x func_ctx = MakeFuncCtx();                                                             \
    Impl_##x<typename ToHipType<T>::MappedType>(                                                 \
        Stream(),                                                                                \
        reinterpret_cast<const typename ToHipType<T>::MappedType*>(p.input_tensor->Data<T>()),   \
        reinterpret_cast<typename ToHipType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
        &func_ctx, p.output_tensor->Shape().Size());                                             \
                                                                                                 \
    return Status::OK();                                                                         \
  }

#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
  REGISTER_ACTIVATION_KERNEL(name, ver, domain, T)      \
  UNARY_ACTIVATION_COMPUTE(name, T)

#define UNARY_ACTIVATION_OP_HFD(name, ver, domain)        \
  UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
  UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float)     \
  UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)

UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);

REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
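Note (not part of the diff): for readers new to these registration macros, this is roughly what the preprocessor emits for a single UNARY_ACTIVATION_OP_TYPED(Gelu, 1, kMSDomain, float) invocation above. It is a hand expansion shown out of context, not standalone compilable code.

ONNX_OPERATOR_TYPED_KERNEL_EX(
    Gelu, kMSDomain, 1, float, kRocmExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .MayInplace(0, 0),
    Gelu<float>);

template <>
Status Gelu<float>::ComputeInternal(OpKernelContext* context) const {
  UnaryElementwisePreparation p;
  ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p));
  CtxGelu func_ctx = MakeFuncCtx();
  Impl_Gelu<typename ToHipType<float>::MappedType>(
      Stream(),
      reinterpret_cast<const typename ToHipType<float>::MappedType*>(p.input_tensor->Data<float>()),
      reinterpret_cast<typename ToHipType<float>::MappedType*>(p.output_tensor->MutableData<float>()),
      &func_ctx, p.output_tensor->Shape().Size());
  return Status::OK();
}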
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations.h (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/unary_elementwise_ops.h"
#include "core/providers/rocm/math/binary_elementwise_ops.h"
#include "core/providers/rocm/activation/activations.h"
#include "activations_impl.h"

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
class Affine final : public UnaryElementwise {
 public:
  Affine(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class ParametricSoftplus final : public UnaryElementwise {
 public:
  ParametricSoftplus(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class ScaledTanh final : public UnaryElementwise {
 public:
  ScaledTanh(const OpKernelInfo& info) : UnaryElementwise(info) {
    ORT_ENFORCE(info.GetAttr("alpha", &alpha_).IsOK());
    ORT_ENFORCE(info.GetAttr("beta", &beta_).IsOK());
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA_BETA()
  float alpha_;
  float beta_;
};

template <typename T>
class Gelu final : public UnaryElementwise {
 public:
  Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_NULL()
};

template <typename T>
class QuickGelu final : public UnaryElementwise {
 public:
  QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
    alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  MAKE_FUNC_CTX_ALPHA()
  float alpha_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.cu (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <hip/hip_runtime.h>
#include "activations_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/unary_elementwise_impl.cuh"

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
struct OP_Affine : public CtxAffine {
  __device__ __inline__ T operator()(const T& a) const {
    return a * (T)alpha + (T)beta;
  }
};

template <typename T>
struct OP_ParametricSoftplus : public CtxParametricSoftplus {
  __device__ __inline__ T operator()(const T& a) const {
    if (a > (T)0)
      return (T)alpha * (a * (T)beta + _Log(_Exp(-a * (T)beta) + (T)1));
    else
      return (T)alpha * _Log(_Exp(a * (T)beta) + (T)1);
  }
};

template <typename T>
struct OP_ScaledTanh : public CtxScaledTanh {
  __device__ __inline__ T operator()(const T& a) const {
    return (T)alpha * _Tanh(a * (T)beta);
  }
};

template <typename T>
struct OP_Gelu : public CtxGelu {
  __device__ __inline__ T operator()(const T& a) const {
    return _Gelu(a);
  }
};

template <>
struct OP_Gelu<half> : public CtxGelu {
  __device__ __inline__ half operator()(const half& a) const {
    return static_cast<half>(_Gelu(static_cast<float>(a)));
  }
};

template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
  __device__ __inline__ T operator()(const T& a) const {
    T v = a * static_cast<T>(alpha);
    T one = static_cast<T>(1.f);
    T zero = static_cast<T>(0.f);
    T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
    return a * sigmoid;
  }
};

#define UNARY_ACTIVATION_IMPL(name)                                        \
  UNARY_ACTIVATION_IMPL_DECLARATION(name) {                                \
    UnaryElementWiseImpl(stream,                                           \
                         input_data,                                       \
                         output_data,                                      \
                         *reinterpret_cast<const OP_##name<T>*>(func_ctx), \
                         count);                                           \
  }

#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
  template void Impl_##name<T>(hipStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);

#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)  \
  SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half)  \
  SPECIALIZED_UNARY_ACTIVATION_IMPL(name, float) \
  SPECIALIZED_UNARY_ACTIVATION_IMPL(name, double)

#define UNARY_ACTIVATION_OP_NAME(name) \
  UNARY_ACTIVATION_IMPL(name);         \
  SPECIALIZED_UNARY_ACTIVATIONL_HFD(name)

UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
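Note (not part of the diff): the OP_QuickGelu functor above computes, with $\alpha$ taken from the kernel's "alpha" attribute,

$$\mathrm{QuickGelu}(x) = x \cdot \sigma(\alpha x), \qquad \sigma(v) = \frac{1}{1 + e^{-v}} = 1 - \frac{1}{1 + e^{v}}.$$

The branch on the sign of $v = \alpha x$ selects whichever of the two equal forms only evaluates $e^{-|v|} \le 1$, so the exponential cannot overflow when T is half.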
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/activation/activations_impl.h (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/rocm/activation/activations_impl.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

typedef onnxruntime::rocm::CtxAlphaBeta CtxAffine;
typedef onnxruntime::rocm::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::rocm::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::rocm::CtxNull CtxGelu;
typedef onnxruntime::rocm::CtxAlpha CtxQuickGelu;

#define UNARY_CONTRIB_ACTIVATION_OPS()         \
  UNARY_ACTIVATION_OP_NAME(ScaledTanh)         \
  UNARY_ACTIVATION_OP_NAME(Affine)             \
  UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
  UNARY_ACTIVATION_OP_NAME(Gelu)               \
  UNARY_ACTIVATION_OP_NAME(QuickGelu)

#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);

UNARY_CONTRIB_ACTIVATION_OPS()
#undef UNARY_ACTIVATION_OP_NAME

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
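Note (not part of the diff): UNARY_CONTRIB_ACTIVATION_OPS is an X-macro list — each consumer redefines UNARY_ACTIVATION_OP_NAME before expanding the list and #undefs it afterwards (this header emits declarations; activations_impl.cu redefines it to emit definitions and instantiations). A minimal standalone sketch of the pattern, with made-up names:

#include <cstdio>

// The op list is written once...
#define CONTRIB_OPS() OP_NAME(ScaledTanh) OP_NAME(Affine) OP_NAME(Gelu)

// ...consumer 1 expands it into declarations...
#define OP_NAME(name) void Run_##name();
CONTRIB_OPS()
#undef OP_NAME

// ...and consumer 2 expands the same list into definitions.
#define OP_NAME(name) void Run_##name() { std::puts(#name); }
CONTRIB_OPS()
#undef OP_NAME

int main() {
  Run_ScaledTanh();
  Run_Affine();
  Run_Gelu();
  return 0;
}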
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/aten_ops/aten_op.cc (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "contrib_ops/cpu/aten_ops/aten_op.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

ONNX_OPERATOR_KERNEL_EX(
    ATen,
    kPytorchAtenDomain,
    1,
    kRocmExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()),
    onnxruntime::contrib::ATen);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.cu (new file, mode 100644)

#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"

namespace onnxruntime {
namespace rocm {

struct __align__(8) Half4 {
  half2 x;
  half2 y;
};

__device__ __forceinline__ Half4 operator+(const Half4& a, const Half4& b) {
  Half4 r;
  r.x = a.x + b.x;
  r.y = a.y + b.y;
  return r;
}

__device__ __forceinline__ float2 operator+(const float2& a, const float2& b) {
  return make_float2(a.x + b.x, a.y + b.y);
}

__device__ __forceinline__ float4 operator+(const float4& a, const float4& b) {
  return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

}  // namespace rocm
}  // namespace onnxruntime

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
__global__ void AddBiasTransposeTrt(const T* input, const T* biases, T* output) {
  // Input:  BxSxMxNxH (Format 2)
  // Output: BxSxNxMxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int H = blockDim.x;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const int NH = N * H;
  const int offset = (b * S + s) * M * NH;
  const int in_offset = offset + m * NH + n * H;
  const int out_offset = offset + (n * M + m) * H;

  const int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int H = head_size;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const int NH = N * H;
  const int offset = (b * S + s) * M * NH;
  const int in_offset = offset + m * NH + n * H;
  const int out_offset = offset + (n * M + m) * H;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTransposeTrt(const T* query, const T* key, const T* value, const T* biases, T* output) {
  // Q:      BxSxNxH
  // K:      BxSxNxH
  // V:      BxSxNxH
  // Output: BxSxNxMxH
  // B is batch_size, S is sequence_length, M is number of matrices (3), N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int H = blockDim.x;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const T* input = (m == 0 ? query : (m == 1 ? key : value));
  const int NH = N * H;
  const int in_offset = (b * S + s) * NH + n * H;
  const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;

  const int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeTrtLarge(const int head_size,
                                         const T* query, const T* key, const T* value, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int stride = blockDim.x;
  const int H = head_size;
  const int N = blockDim.y;
  const int S = gridDim.x;
  const int M = gridDim.z;

  const T* input = (m == 0 ? query : (m == 1 ? key : value));
  const int NH = N * H;
  const int in_offset = (b * S + s) * NH + n * H;
  const int out_offset = (b * S + s) * M * NH + (n * M + m) * H;

  int h = threadIdx.x;
  if (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output) {
  // Input:  BxSxMxNxH (Format 1)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int M = gridDim.z;
  const int H = head_size;

  const int NH = num_heads * head_size;
  const int NHS = NH * sequence_length;

  int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
  const int out_offset = s * head_size + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  const int h = threadIdx.x;
  if (h < head_size) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
  // Input:  BxSxMxNxH (Format 1)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;        // head_num_id
  int s = blockIdx.x;         // sequence_id
  int b = blockIdx.y;         // batch_id
  int m = blockIdx.z;         // matrix id (Q=0, K=1, V=2)
  const int h = threadIdx.x;  // head_element_id

  const int qk_head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  const int head_size = (m == 2 ? v_head_size : qk_head_size);

  const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);

  int in_offset;
  int out_offset;
  int bias_offset;
  in_offset = b * (total_head_size * sequence_length) +  // B
              s * (total_head_size) +                    // S
              m * (qk_head_size * num_heads) +           // M
              n * head_size +                            // N
              h;                                         // H

  out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) +  // M
               b * (num_heads * head_size * sequence_length) +                  // B
               n * (sequence_length * head_size) +                              // N
               s * (head_size) +                                                // S
               h;                                                               // H

  bias_offset = m * (num_heads * qk_head_size) +  // M
                n * (head_size) +                 // N
                h;                                // H

  if (h < head_size) {
    output[out_offset] = input[in_offset] + biases[bias_offset];
  }
}

template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int M = gridDim.z;
  const int H = head_size;

  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
  const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
  // Input:  MxBxSxNxH (Format 0)
  // Output: MxBxNxSxH
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int head_size = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int H = head_size;
  const int NH = num_heads * head_size;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
  const int out_offset = s * H + n * sequence_length * H + (b + m * batch_size) * NHS;

  const int h = threadIdx.x;
  if (h < head_size) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
  }
}

template <typename T>
__global__ void AddBiasTransposeLarge(const int head_size, const T* input, const T* biases, T* output) {
  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;

  const int stride = blockDim.x;
  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  const int H = head_size;
  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;

  int in_offset = n * H + s * NH + (b + m * batch_size) * NHS;
  const int out_offset = (s + n * sequence_length) * H + (b + m * batch_size) * NHS;

  int h = threadIdx.x;
  while (h < H) {
    output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
    h += stride;
  }
}

template <typename T>
void InvokeAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const T* input, const T* biases, T* output, const int v_head_size) {
  const dim3 grid(sequence_length, batch_size, num_matrices);
  if (qk_head_size * num_heads <= max_threads_per_block) {
    const dim3 block(qk_head_size, num_heads, 1);
    if (format == 2) {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, input, biases, output);
    } else if (format == 1) {
      if (v_head_size == -1 || qk_head_size == v_head_size) {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output);
      } else {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKV<T>), grid, block, 0, stream, input, biases, output, v_head_size);
      }
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTranspose<T>), grid, block, 0, stream, input, biases, output);
    }
  } else {
    const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
    if (format == 2) {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
    } else if (format == 1) {
      if (v_head_size == -1 || qk_head_size == v_head_size) {
        hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeQKVLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
      } else {
        ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block");
      }
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeLarge<T>), grid, block, 0, stream, qk_head_size, input, biases, output);
    }
  }
}

template <>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size) {
  if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
    const int H = qk_head_size / 4;
    const int H_v = v_head_size / 4;
    const Half4* input2 = reinterpret_cast<const Half4*>(input);
    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
    Half4* output2 = reinterpret_cast<Half4*>(output);
    InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
    const int H = qk_head_size / 2;
    const int H_v = v_head_size / 2;
    const half2* input2 = reinterpret_cast<const half2*>(input);
    const half2* biases2 = reinterpret_cast<const half2*>(biases);
    half2* output2 = reinterpret_cast<half2*>(output);
    InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H, input2, biases2, output2, H_v);
  } else {
    InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
                                 batch_size, sequence_length, num_heads, qk_head_size,
                                 input, biases, output, v_head_size);
  }
}

template <>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size) {
  if (0 == (qk_head_size % 4)) {
    const int H = qk_head_size / 4;
    const float4* input2 = reinterpret_cast<const float4*>(input);
    const float4* biases2 = reinterpret_cast<const float4*>(biases);
    float4* output2 = reinterpret_cast<float4*>(output);
    InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
  } else if (0 == (qk_head_size & 1)) {
    const int H = qk_head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    const float2* biases2 = reinterpret_cast<const float2*>(biases);
    float2* output2 = reinterpret_cast<float2*>(output);
    InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
  } else {
    InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, qk_head_size,
                                  input, biases, output, v_head_size);
  }
}

template <typename T>
void InvokeAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const T* biases, const T* query, const T* key, const T* value, T* output) {
  constexpr int num_matrices = 3;
  const dim3 grid(sequence_length, batch_size, num_matrices);
  if (head_size * num_heads <= max_threads_per_block) {
    const dim3 block(head_size, num_heads, 1);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrt<T>), grid, block, 0, stream, query, key, value, biases, output);
  } else {
    const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
    hipLaunchKernelGGL(HIP_KERNEL_NAME(AddBiasTransposeTrtLarge<T>), grid, block, 0, stream, head_size, query, key, value, biases, output);
  }
}

template <>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const float* biases, const float* query, const float* key, const float* value, float* output) {
  ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
}

template <>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const half* biases, const half* query, const half* key, const half* value, half* output) {
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const Half4* query2 = reinterpret_cast<const Half4*>(query);
    const Half4* key2 = reinterpret_cast<const Half4*>(key);
    const Half4* value2 = reinterpret_cast<const Half4*>(value);
    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
    Half4* output2 = reinterpret_cast<Half4*>(output);
    InvokeAddBiasTransposeTrt<Half4>(stream, max_threads_per_block,
                                     batch_size, sequence_length, num_heads, H,
                                     biases2, query2, key2, value2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const half2* query2 = reinterpret_cast<const half2*>(query);
    const half2* key2 = reinterpret_cast<const half2*>(key);
    const half2* value2 = reinterpret_cast<const half2*>(value);
    const half2* biases2 = reinterpret_cast<const half2*>(biases);
    half2* output2 = reinterpret_cast<half2*>(output);
    InvokeAddBiasTransposeTrt<half2>(stream, max_threads_per_block,
                                     batch_size, sequence_length, num_heads, H,
                                     biases2, query2, key2, value2, output2);
  } else {
    InvokeAddBiasTransposeTrt<half>(stream, max_threads_per_block,
                                    batch_size, sequence_length, num_heads, head_size,
                                    biases, query, key, value, output);
  }
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/add_bias_transpose.h (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

// Fused kernel of Add (bias) and Transpose.
// Shape of inputs and outputs:
//     biases:  (num_matrices, num_heads * head_size)
// format 0:
//     input:   (num_matrices, batch_size, sequence_length, num_heads, head_size)
//     output:  (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 1:
//     input:   (batch_size, sequence_length, num_matrices, num_heads, head_size)
//     output:  (num_matrices, batch_size, num_heads, sequence_length, head_size)
// format 2:
//     input:   (batch_size, sequence_length, num_matrices, num_heads, head_size)
//     output:  (batch_size, sequence_length, num_heads, num_matrices, head_size)
template <typename T>
void LaunchAddBiasTranspose(
    hipStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);

// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
//     output:  (batch_size, sequence_length, num_heads, num_matrices, head_size)
// It assumes sequence_length == kv_sequence_length and head_size == v_head_size.
template <typename T>
void LaunchAddBiasTransposeTrt(
    hipStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int head_size,
    const T* biases, const T* query, const T* key, const T* value, T* output);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
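Note (not part of the diff): a minimal sketch of how a caller might drive the format-1 (packed QKV) path declared above, assuming device buffers laid out as documented and a valid hipStream_t; the wrapper name and all dimension values are made-up examples.

#include "contrib_ops/rocm/bert/add_bias_transpose.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

void ExampleAddBiasTransposeQKV(hipStream_t stream, const float* packed_qkv, const float* bias, float* out) {
  const int num_matrices = 3;  // Q, K and V packed as BxSxMxNxH (format 1 input)
  const int format = 1;        // output becomes MxBxNxSxH
  const int max_threads_per_block = 256;
  const int batch_size = 2, sequence_length = 128, num_heads = 12, head_size = 64;
  LaunchAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
                                batch_size, sequence_length, num_heads, head_size,
                                packed_qkv, bias, out,
                                /*enable_half4=*/false,      // ignored by the float specialization
                                /*v_head_size=*/head_size);  // equal head sizes for Q/K/V here
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime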
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_concat.cu (new file, mode 100644)

#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"

using namespace onnxruntime::rocm;

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
__global__ void ConcatTensorToTensor(const int tensor_add_sequence_length,
                                     const T* tensor_in, const T* tensor_add, T* tensor_out) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int chunk_id = blockIdx.z;

  const int all_sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int num_heads = blockDim.y;
  const int H = blockDim.x;

  // K: number of identical tensors
  // tensor_in:  K x BxNxPxH
  // tensor_add: K x BxNxLxH
  // tensor_out: K x BxNxTxH, where T = P + L
  const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;

  const int present_SH = all_sequence_length * H;
  const int present_NSH = num_heads * present_SH;
  int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
  if (s < tensor_in_sequence_length) {
    const int past_SH = tensor_in_sequence_length * H;
    const int past_NSH = num_heads * past_SH;
    const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
    tensor_out[out_offset] = tensor_in[in_offset];
  } else if (s < all_sequence_length) {
    const int SH = tensor_add_sequence_length * H;
    const int NSH = num_heads * SH;
    const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
    tensor_out[out_offset] = tensor_add[in_offset];
  }
}

template <typename T>
__global__ void ConcatTensorToTensorLarge(const int tensor_add_sequence_length, const int H,
                                          const T* tensor_in, const T* tensor_add, T* tensor_out) {
  // Use when (H*)*num_heads > 1024
  int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int chunk_id = blockIdx.z;

  const int all_sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int num_heads = blockDim.y;
  const int stride = blockDim.x;

  // K: number of identical tensor
  // tensor_in:  K x BxNxPxH
  // tensor_add: K x BxNxLxH
  // tensor_out: K x BxNxTxH
  const int tensor_in_sequence_length = all_sequence_length - tensor_add_sequence_length;

  const int present_SH = all_sequence_length * H;
  const int present_NSH = num_heads * present_SH;
  while (h < H) {
    int out_offset = b * present_NSH + n * present_SH + s * H + h + chunk_id * (present_NSH * batch_size);
    if (s < tensor_in_sequence_length) {
      const int past_SH = tensor_in_sequence_length * H;
      const int past_NSH = num_heads * past_SH;
      const int in_offset = b * past_NSH + n * past_SH + s * H + h + chunk_id * (past_NSH * batch_size);
      tensor_out[out_offset] = tensor_in[in_offset];
    } else if (s < all_sequence_length) {
      const int SH = tensor_add_sequence_length * H;
      const int NSH = num_heads * SH;
      const int in_offset = b * NSH + n * SH + (s - tensor_in_sequence_length) * H + h + chunk_id * (NSH * batch_size);
      tensor_out[out_offset] = tensor_add[in_offset];
    }
    h += stride;
  }
}

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const float* tensor_in, const float* tensor_add, float* tensor_out) {
  const dim3 grid(all_sequence_length, batch_size, matrix_num);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length, H,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    }
  } else {
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float>), grid, block, 0, stream, sequence_length,
                         tensor_in, tensor_add, tensor_out);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float>), grid, block, 0, stream, sequence_length, head_size,
                         tensor_in, tensor_add, tensor_out);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const half* tensor_in, const half* tensor_add, half* tensor_out) {
  const dim3 grid(all_sequence_length, batch_size, matrix_num);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<float2>), grid, block, 0, stream, sequence_length,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<float2>), grid, block, 0, stream, sequence_length, H,
                         reinterpret_cast<const float2*>(tensor_in),
                         reinterpret_cast<const float2*>(tensor_add),
                         reinterpret_cast<float2*>(tensor_out));
    }
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half2>), grid, block, 0, stream, sequence_length,
                         reinterpret_cast<const half2*>(tensor_in),
                         reinterpret_cast<const half2*>(tensor_add),
                         reinterpret_cast<half2*>(tensor_out));
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half2>), grid, block, 0, stream, sequence_length, H,
                         reinterpret_cast<const half2*>(tensor_in),
                         reinterpret_cast<const half2*>(tensor_add),
                         reinterpret_cast<half2*>(tensor_out));
    }
  } else {
    // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensor<half>), grid, block, 0, stream, sequence_length,
                         tensor_in, tensor_add, tensor_out);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(ConcatTensorToTensorLarge<half>), grid, block, 0, stream, sequence_length, head_size,
                         tensor_in, tensor_add, tensor_out);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const float* past, const float* k_v, float* present) {
  return LaunchConcatTensorToTensor(stream, all_sequence_length, sequence_length, batch_size, head_size,
                                    num_heads, max_threads_per_block, 2, past, k_v, present);
}

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const half* past, const half* k_v, half* present) {
  return LaunchConcatTensorToTensor(stream, all_sequence_length, sequence_length, batch_size, head_size,
                                    num_heads, max_threads_per_block, 2, past, k_v, present);
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_impl.h (new file, mode 100644)

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

size_t GetAttentionScratchSize(
    size_t element_size,
    size_t batch_size,
    size_t num_heads,
    size_t sequence_length,
    size_t all_sequence_length);

size_t GetAttentionWorkspaceSize(
    size_t element_size,
    size_t batchsize,
    size_t num_heads,
    size_t qk_head_size,
    size_t v_head_size,
    size_t sequence_length,
    size_t kv_sequence_length,
    size_t total_sequence_length,
    void* fused_runner);

template <typename T>
struct AttentionData {
  const T* gemm_buffer;
  const T* bias;

  const T* query;
  const T* key;
  const T* value;
  const int* mask_index;
  gsl::span<const int64_t> mask_index_dims;
  const T* past;
  const T* extra_add_qk;

  T* workspace;
  T* output;
  T* present;
};

template <typename T>
Status QkvToContext(
    const hipDeviceProp_t& prop,
    rocblas_handle& rocblas,
    hipStream_t stream,
    contrib::AttentionParameters& parameters,
    AttentionData<T>& data,
    void* fused_runner);

Status LaunchDecoderAttentionKernel(
    const hipDeviceProp_t& prop,      // Device Properties
    hipStream_t stream,               // Cuda stream
    rocblas_handle& rocblas,          // Rocblas handle
    const size_t element_size,        // Element size of input tensor
    const int batch_size,             // Batch size (B)
    const int sequence_length,        // Sequence length (S)
    const int kv_sequence_length,     // Key/Value/Cache sequence length
    const int num_heads,              // Number of attention heads (N)
    const int head_size,              // Hidden size per head (H)
    const bool static_kv,             // Whether cross attention or not
    const bool use_past,              // Whether use cache or not
    const bool has_layer_state,       // Whether output cache or not
    const bool has_key_padding_mask,  // Whether use key_padding_mask or not
    const void* gemm_query_buffer,    // Query buffer
    const void* gemm_kv_buffer,       // Key and value buffer
    const bool* key_padding_mask,     // Key padding mask
    const void* key_cache,            // Input key cache
    const void* value_cache,          // Input value cache
    void* qkv_buffer,                 // Temporary buffer
    void* workspace_buffer,           // Temporary buffer
    void* output,                     // Output tensor
    void* new_key_cache,              // New_key_cache tensor
    void* new_value_cache             // New_value_cache tensor
);

// BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true)
Status LaunchTransCtx(hipStream_t stream,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);

Status LaunchTransCtx(hipStream_t stream,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);

// BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH
Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const float* input, float* output);

Status LaunchTransQkv(hipStream_t stream, const int matrix_num,
                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const float* tensor_in, const float* tensor_add, float* tensor_out);

Status LaunchConcatTensorToTensor(hipStream_t stream,
                                  const int all_sequence_length, const int sequence_length,
                                  const int batch_size, const int head_size, const int num_heads,
                                  const int max_threads_per_block, const int matrix_num,
                                  const half* tensor_in, const half* tensor_add, half* tensor_out);

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const float* past, const float* k_v, float* present);

Status LaunchConcatPastToPresent(hipStream_t stream,
                                 const int all_sequence_length, const int sequence_length,
                                 const int batch_size, const int head_size, const int num_heads,
                                 const int max_threads_per_block,
                                 const half* past, const half* k_v, half* present);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
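Note (not part of the diff): a minimal sketch of appending one decoding step's K/V to the cache through LaunchConcatPastToPresent, which (per attention_concat.cu above) forwards to LaunchConcatTensorToTensor with matrix_num = 2. The wrapper name and all sizes are made-up examples; the buffers are assumed to be device memory in the documented K x BxNxSxH layout.

#include "contrib_ops/rocm/bert/attention_impl.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

Status ExampleAppendKvCache(hipStream_t stream, const float* past, const float* k_v, float* present) {
  const int past_sequence_length = 31;  // P: tokens already cached
  const int sequence_length = 1;        // L: tokens added this step
  const int all_sequence_length = past_sequence_length + sequence_length;  // T = P + L
  const int batch_size = 4, head_size = 64, num_heads = 16;
  const int max_threads_per_block = 256;
  return LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size,
                                   head_size, num_heads, max_threads_per_block, past, k_v, present);
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime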
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/attention_transpose.cu
0 → 100644
View file @
1a91fcc2
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Modifications: add transpose kernels for TRT format
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using
namespace
onnxruntime
::
rocm
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
template
<
typename
T
>
__global__
void
TransposeCtx
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Input: BxNxSxH
// Output: BxSxNxH
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
num_heads
=
blockDim
.
y
;
int
sequence_length
=
gridDim
.
x
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
in_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
;
int
out_offset
=
0
;
if
(
reversed_bs
)
{
const
int
batch_size
=
gridDim
.
y
;
const
int
BNH
=
NH
*
batch_size
;
out_offset
=
n
*
H
+
b
*
NH
+
s
*
BNH
;
}
else
{
out_offset
=
n
*
H
+
s
*
NH
+
b
*
NHS
;
}
const
int
i
=
threadIdx
.
x
;
if
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
}
}
template
<
typename
T
>
__global__
void
TransposeCtxLarge
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Use when (H*)*num_heads > 1024
// Input: BxNxSxH
// Output: BxSxNxH
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
stride
=
blockDim
.
x
;
int
num_heads
=
blockDim
.
y
;
int
sequence_length
=
gridDim
.
x
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
in_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
;
int
out_offset
=
0
;
if
(
reversed_bs
)
{
const
int
batch_size
=
gridDim
.
y
;
const
int
BNH
=
NH
*
batch_size
;
out_offset
=
n
*
H
+
b
*
NH
+
s
*
BNH
;
}
else
{
out_offset
=
n
*
H
+
s
*
NH
+
b
*
NHS
;
}
int
i
=
threadIdx
.
x
;
while
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
i
+=
stride
;
}
}
Status
LaunchTransCtx
(
hipStream_t
stream
,
const
int
sequence_length
,
const
int
batch_size
,
const
int
head_size
,
const
int
num_heads
,
const
int
max_threads_per_block
,
const
bool
reversed_bs
,
const
float
*
input
,
float
*
output
)
{
const
dim3
grid
(
sequence_length
,
batch_size
,
1
);
if
(
0
==
(
head_size
&
1
))
{
const
int
H
=
head_size
/
2
;
const
float2
*
input2
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
output2
=
reinterpret_cast
<
float2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
{
if
(
head_size
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
head_size
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
}
return
HIP_CALL
(
hipGetLastError
());
}
Status
LaunchTransCtx
(
hipStream_t
stream
,
const
int
sequence_length
,
const
int
batch_size
,
const
int
head_size
,
const
int
num_heads
,
const
int
max_threads_per_block
,
const
bool
reversed_bs
,
const
half
*
input
,
half
*
output
)
{
const
dim3
grid
(
sequence_length
,
batch_size
,
1
);
if
(
0
==
(
head_size
%
4
))
{
const
int
H
=
head_size
/
4
;
const
float2
*
input2
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
output2
=
reinterpret_cast
<
float2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
float2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
if
(
0
==
(
head_size
&
1
))
{
const
int
H
=
head_size
/
2
;
const
half2
*
input2
=
reinterpret_cast
<
const
half2
*>
(
input
);
half2
*
output2
=
reinterpret_cast
<
half2
*>
(
output
);
if
(
H
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
H
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
half2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
half2
>
),
grid
,
block
,
0
,
stream
,
H
,
reversed_bs
,
input2
,
output2
);
}
}
else
{
// this should be an "odd" case. probably not worth catching it in the half2 kernel.
if
(
head_size
*
num_heads
<=
max_threads_per_block
)
{
const
dim3
block
(
head_size
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtx
<
half
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
else
{
const
dim3
block
(
max_threads_per_block
/
num_heads
,
num_heads
,
1
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
TransposeCtxLarge
<
half
>
),
grid
,
block
,
0
,
stream
,
head_size
,
reversed_bs
,
input
,
output
);
}
}
return
HIP_CALL
(
hipGetLastError
());
}
template
<
typename
T
>
__global__
void
TransposeQKV
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
m
=
blockIdx
.
z
;
// matrix id
const
int
num_heads
=
blockDim
.
y
;
const
int
sequence_length
=
gridDim
.
x
;
const
int
batch_size
=
gridDim
.
y
;
const
int
chunk_num
=
gridDim
.
z
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
int
in_offset
=
0
;
if
(
reversed_bs
)
{
const
int
BNH
=
NH
*
batch_size
;
in_offset
=
n
*
H
+
(
m
+
b
*
chunk_num
)
*
NH
+
s
*
BNH
*
chunk_num
;
}
else
{
in_offset
=
n
*
H
+
(
m
+
s
*
chunk_num
)
*
NH
+
b
*
NHS
*
chunk_num
;
}
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
const
int
i
=
threadIdx
.
x
;
if
(
i
<
H
)
{
output
[
out_offset
+
i
]
=
input
[
in_offset
+
i
];
}
}
template
<
typename
T
>
__global__
void
TransposeQKVLarge
(
const
int
H
,
const
bool
reversed_bs
,
const
T
*
input
,
T
*
output
)
{
// Use when (H*)*num_heads > 1024
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
int
n
=
threadIdx
.
y
;
int
s
=
blockIdx
.
x
;
int
b
=
blockIdx
.
y
;
int
m
=
blockIdx
.
z
;
// matrix id
const
int
stride
=
blockDim
.
x
;
const
int
num_heads
=
blockDim
.
y
;
const
int
sequence_length
=
gridDim
.
x
;
const
int
batch_size
=
gridDim
.
y
;
const
int
chunk_num
=
gridDim
.
z
;
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
int
in_offset
=
0
;
if
(
reversed_bs
)
{
const
int
BNH
=
NH
*
batch_size
;
in_offset
=
n
*
H
+
(
m
+
b
*
chunk_num
)
*
NH
+
s
*
BNH
*
chunk_num
;
}
else
{
in_offset
=
n
*
H
+
(
m
+
s
*
chunk_num
)
*
NH
+
b
*
NHS
*
chunk_num
;
}
const
int
  out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  int i = threadIdx.x;
  while (i < H) {
    output[out_offset + i] = input[in_offset + i];
    i += stride;
  }
}

Status LaunchTransQkv(hipStream_t stream, const int matrix_num, const int sequence_length,
                      const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs,
                      const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, matrix_num);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else {
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    }
  }
  return HIP_CALL(hipGetLastError());
}

Status LaunchTransQkv(hipStream_t stream, const int matrix_num, const int sequence_length,
                      const int batch_size, const int head_size, const int num_heads,
                      const int max_threads_per_block, const bool reversed_bs,
                      const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, matrix_num);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<float2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    if (H * num_heads <= max_threads_per_block) {
      const dim3 block(H, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half2>), grid, block, 0, stream,
                         H, reversed_bs, input2, output2);
    }
  } else {
    // this should be an "odd" case. probably not worth catching it in the half2 kernel..
    if (head_size * num_heads <= max_threads_per_block) {
      const dim3 block(head_size, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKV<half>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    } else {
      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
      hipLaunchKernelGGL(HIP_KERNEL_NAME(TransposeQKVLarge<half>), grid, block, 0, stream,
                         head_size, reversed_bs, input, output);
    }
  }
  return HIP_CALL(hipGetLastError());
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
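The launchers above pick a vector width from the head size (float2 for an even float head size; float2 or half2 for half when the head size is divisible by 4 or 2) and dispatch TransposeQKV when one block can cover H * num_heads threads, falling back to TransposeQKVLarge otherwise. The standalone sketch below is illustrative only: the tiny dimensions B, S, N, H, M are assumed, and it merely checks that the reversed_bs output-offset formula visible at the top of this file tiles the output exactly once.

#include <cstdio>
#include <vector>

int main() {
  const int B = 2, S = 3, N = 2, H = 4, M = 3;  // tiny example sizes (assumed, for illustration only)
  const int NHS = N * S * H;
  std::vector<int> hits(M * B * NHS, 0);
  for (int m = 0; m < M; ++m)
    for (int b = 0; b < B; ++b)
      for (int n = 0; n < N; ++n)
        for (int s = 0; s < S; ++s) {
          // Same arithmetic as the reversed_bs out_offset in the kernel above.
          int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
          for (int i = 0; i < H; ++i) hits[out_offset + i]++;
        }
  for (int v : hits)
    if (v != 1) { printf("offset formula does not tile the output\n"); return 1; }
  printf("every output element is written exactly once\n");
  return 0;
}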
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.cu
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TrtSequenceOffset kernels are modified from FasterTransformer
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/bert_padding.h"
using
namespace
onnxruntime
::
rocm
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
constexpr
int32_t
kMAX_THREADS_PER_BLOCK
=
256
;
// -----------------------------------
// Get indices of non-padding tokens and padding tokens. Here we assume that padding is on the right side of sequence.
// sequence_token_count is number of non-padding tokens per sequence, and it has shape [batch_size].
// For example, we have 3 sequences with 1, 2, 4 non-padding tokens and positions like the following (* means padding):
// Sequence_0: 0, 1*, 2*, 3*
// Sequence_1: 4, 5, 6*, 7*
// Sequence_2: 8, 9, 10, 11
// token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7*
// token_count_buffer has two numbers for non-padding tokens:
// total_token_count: 1 + 2 + 4 = 7
// max_token_count: 4
// cumulated_token_count: 0, 1, 1+2, 1+2+4
__global__
void
getTokenOffset
(
int
*
token_count_buffer
,
int
*
token_offset
,
int
*
cumulated_token_count
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
)
{
// Find offset of non-padding tokens, and max sequence length among all batches
// TODO(tianleiwu): Use hipcub::DevicePartition::Flagged like BuildGlobalIndex in longformer_global_impl.cu
// to build token_offset when sequence length is large.
int
total_tokens
=
0
;
int
max_tokens
=
0
;
int
index
=
0
;
cumulated_token_count
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
int
count
=
sequence_token_count
[
i
];
if
(
count
>
max_tokens
)
{
max_tokens
=
count
;
}
cumulated_token_count
[
i
+
1
]
=
cumulated_token_count
[
i
]
+
count
;
for
(
int
j
=
0
;
j
<
count
;
j
++
)
{
token_offset
[
index
]
=
i
*
sequence_length
+
j
;
index
++
;
}
total_tokens
+=
count
;
}
// Offset of paddings
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
const
int
count
=
sequence_token_count
[
i
];
for
(
int
j
=
0
;
j
<
sequence_length
-
count
;
j
++
)
{
token_offset
[
index
]
=
i
*
sequence_length
+
count
+
j
;
index
++
;
}
}
token_count_buffer
[
0
]
=
total_tokens
;
token_count_buffer
[
1
]
=
max_tokens
;
}
void
LaunchGetTokenOffset
(
int
*
token_count_buffer
,
int
*
token_offset
,
int
*
cumulated_token_count
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
,
hipStream_t
stream
)
{
hipLaunchKernelGGL
(
getTokenOffset
,
1
,
1
,
0
,
stream
,
token_count_buffer
,
token_offset
,
cumulated_token_count
,
sequence_token_count
,
batch_size
,
sequence_length
);
}
// -----------------------------------
// Remove paddings
template
<
typename
T
>
__global__
void
__launch_bounds__
(
kMAX_THREADS_PER_BLOCK
)
removePadding
(
T
*
target
,
const
T
*
source
,
const
int
*
token_offset
,
const
int
width
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
token_index
=
blockIdx
.
x
;
const
int
source_offset
=
token_offset
[
token_index
];
const
int
target_offset
=
token_index
;
for
(
int
i
=
tid
;
i
<
width
;
i
+=
blockDim
.
x
)
{
target
[
target_offset
*
width
+
i
]
=
source
[
source_offset
*
width
+
i
];
}
}
template
<
>
void
LaunchRemovePadding
(
half
*
output
,
const
half
*
input
,
const
int
*
token_offset
,
const
int
token_count
,
const
int
hidden_size
,
hipStream_t
stream
)
{
// input: [batch_size, sequence_length, hidden_size]
// output: [token_count, hidden_size]
// Make sure memory is aligned to 128 bit
ORT_ENFORCE
(
!
(
reinterpret_cast
<
size_t
>
(
input
)
&
0xF
)
&&
!
(
reinterpret_cast
<
size_t
>
(
output
)
&
0xF
),
"alignment"
);
if
(
hidden_size
%
8
==
0
)
{
const
int
width
=
hidden_size
/
8
;
const
int4
*
input2
=
reinterpret_cast
<
const
int4
*>
(
input
);
int4
*
output2
=
reinterpret_cast
<
int4
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int4
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
else
if
(
hidden_size
%
4
==
0
)
{
const
int
width
=
hidden_size
/
4
;
const
int64_t
*
input2
=
reinterpret_cast
<
const
int64_t
*>
(
input
);
int64_t
*
output2
=
reinterpret_cast
<
int64_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int64_t
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
else
if
(
hidden_size
%
2
==
0
)
{
const
int
width
=
hidden_size
/
2
;
const
int32_t
*
input2
=
reinterpret_cast
<
const
int32_t
*>
(
input
);
int32_t
*
output2
=
reinterpret_cast
<
int32_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int32_t
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
else
{
const
int
width
=
hidden_size
;
const
int16_t
*
input2
=
reinterpret_cast
<
const
int16_t
*>
(
input
);
int16_t
*
output2
=
reinterpret_cast
<
int16_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int16_t
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
}
template
<
>
void
LaunchRemovePadding
(
float
*
output
,
const
float
*
input
,
const
int
*
token_offset
,
const
int
token_count
,
const
int
hidden_size
,
hipStream_t
stream
)
{
ORT_ENFORCE
(
!
(
reinterpret_cast
<
size_t
>
(
input
)
&
0xF
)
&&
!
(
reinterpret_cast
<
size_t
>
(
output
)
&
0xF
),
"alignment"
);
if
(
hidden_size
%
4
==
0
)
{
const
int
width
=
hidden_size
/
4
;
const
int4
*
input2
=
reinterpret_cast
<
const
int4
*>
(
input
);
int4
*
output2
=
reinterpret_cast
<
int4
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int4
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
else
if
(
hidden_size
%
2
==
0
)
{
const
int
width
=
hidden_size
/
2
;
const
int64_t
*
input2
=
reinterpret_cast
<
const
int64_t
*>
(
input
);
int64_t
*
output2
=
reinterpret_cast
<
int64_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int64_t
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
else
{
const
int
width
=
hidden_size
;
const
int32_t
*
input2
=
reinterpret_cast
<
const
int32_t
*>
(
input
);
int32_t
*
output2
=
reinterpret_cast
<
int32_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
removePadding
<
int32_t
>
),
token_count
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
);
}
}
// -----------------------------------
// Recover padding.
template
<
typename
T
>
__global__
void
__launch_bounds__
(
kMAX_THREADS_PER_BLOCK
)
restorePadding
(
T
*
target
,
const
T
*
source
,
const
int
*
token_offset
,
const
int
width
,
const
int
token_count
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
token_index
=
blockIdx
.
x
;
const
int
target_seq_id
=
token_offset
[
token_index
];
const
int
source_seq_id
=
token_index
;
constexpr
T
padding_zero
=
0
;
if
(
token_index
<
token_count
)
{
for
(
int
i
=
tid
;
i
<
width
;
i
+=
blockDim
.
x
)
{
target
[
target_seq_id
*
width
+
i
]
=
source
[
source_seq_id
*
width
+
i
];
}
}
else
{
// It is padding: fill with zeros
for
(
int
i
=
tid
;
i
<
width
;
i
+=
blockDim
.
x
)
{
target
[
target_seq_id
*
width
+
i
]
=
padding_zero
;
}
}
}
template
<
>
__global__
void
__launch_bounds__
(
kMAX_THREADS_PER_BLOCK
)
restorePadding
(
int4
*
target
,
const
int4
*
source
,
const
int
*
token_offset
,
const
int
width
,
const
int
token_count
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
token_index
=
blockIdx
.
x
;
const
int
target_seq_id
=
token_offset
[
token_index
];
const
int
source_seq_id
=
token_index
;
int4
padding_zero
{
0
,
0
,
0
,
0
};
if
(
token_index
<
token_count
)
{
for
(
int
i
=
tid
;
i
<
width
;
i
+=
blockDim
.
x
)
{
target
[
target_seq_id
*
width
+
i
]
=
source
[
source_seq_id
*
width
+
i
];
}
}
else
{
// It is padding: fill with zeros
for
(
int
i
=
tid
;
i
<
width
;
i
+=
blockDim
.
x
)
{
target
[
target_seq_id
*
width
+
i
]
=
padding_zero
;
}
}
}
template
<
>
void
LaunchRestorePadding
(
float
*
output
,
const
float
*
input
,
const
int
*
token_offset
,
const
int
token_count
,
const
int
hidden_size
,
const
int
batch_size
,
const
int
sequence_length
,
hipStream_t
stream
)
{
ORT_ENFORCE
(
!
(
reinterpret_cast
<
size_t
>
(
input
)
&
0xF
)
&&
!
(
reinterpret_cast
<
size_t
>
(
output
)
&
0xF
),
"alignment"
);
int
grid_size
=
batch_size
*
sequence_length
;
if
(
hidden_size
%
4
==
0
)
{
const
int
width
=
hidden_size
/
4
;
const
int4
*
input2
=
reinterpret_cast
<
const
int4
*>
(
input
);
int4
*
output2
=
reinterpret_cast
<
int4
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int4
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
else
if
(
hidden_size
%
2
==
0
)
{
const
int
width
=
hidden_size
/
2
;
const
int64_t
*
input2
=
reinterpret_cast
<
const
int64_t
*>
(
input
);
int64_t
*
output2
=
reinterpret_cast
<
int64_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int64_t
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
else
{
const
int
width
=
hidden_size
;
const
int32_t
*
input2
=
reinterpret_cast
<
const
int32_t
*>
(
input
);
int32_t
*
output2
=
reinterpret_cast
<
int32_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int32_t
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
}
template
<
>
void
LaunchRestorePadding
(
half
*
output
,
const
half
*
input
,
const
int
*
token_offset
,
const
int
token_count
,
const
int
hidden_size
,
const
int
batch_size
,
const
int
sequence_length
,
hipStream_t
stream
)
{
// input: [token_count, hidden_size]
// output: [batch_size, sequence_length, hidden_size]
ORT_ENFORCE
(
!
(
reinterpret_cast
<
size_t
>
(
input
)
&
0xF
)
&&
!
(
reinterpret_cast
<
size_t
>
(
output
)
&
0xF
),
"alignment"
);
int
grid_size
=
batch_size
*
sequence_length
;
if
(
hidden_size
%
8
==
0
)
{
const
int
width
=
hidden_size
/
8
;
const
int4
*
input2
=
reinterpret_cast
<
const
int4
*>
(
input
);
int4
*
output2
=
reinterpret_cast
<
int4
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int4
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
else
if
(
hidden_size
%
4
==
0
)
{
const
int
width
=
hidden_size
/
4
;
const
int64_t
*
input2
=
reinterpret_cast
<
const
int64_t
*>
(
input
);
int64_t
*
output2
=
reinterpret_cast
<
int64_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int64_t
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
else
if
(
hidden_size
%
2
==
0
)
{
const
int
width
=
hidden_size
/
2
;
const
int32_t
*
input2
=
reinterpret_cast
<
const
int32_t
*>
(
input
);
int32_t
*
output2
=
reinterpret_cast
<
int32_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int32_t
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
else
{
const
int
width
=
hidden_size
;
const
int16_t
*
input2
=
reinterpret_cast
<
const
int16_t
*>
(
input
);
int16_t
*
output2
=
reinterpret_cast
<
int16_t
*>
(
output
);
hipLaunchKernelGGL
(
HIP_KERNEL_NAME
(
restorePadding
<
int16_t
>
),
grid_size
,
kMAX_THREADS_PER_BLOCK
,
0
,
stream
,
output2
,
input2
,
token_offset
,
width
,
token_count
);
}
}
__global__
void
__launch_bounds__
(
kMAX_THREADS_PER_BLOCK
)
getTrtSequenceOffset
(
int
*
trt_mha_padding_offset
,
const
int
*
sequence_token_count
,
const
int
batch_size
)
{
extern
__shared__
int
tmp_offset
[];
if
(
threadIdx
.
x
==
0
)
{
tmp_offset
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
tmp_offset
[
i
+
1
]
=
tmp_offset
[
i
]
+
sequence_token_count
[
i
];
}
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_size
+
1
;
i
+=
blockDim
.
x
)
{
trt_mha_padding_offset
[
i
]
=
tmp_offset
[
i
];
}
}
// Get sequence offset for TensorRT fused attention when there is no padding (or padding is removed)
void
LaunchTrtSequenceOffset
(
int
*
trt_mha_padding_offset
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
hipStream_t
stream
)
{
hipLaunchKernelGGL
(
getTrtSequenceOffset
,
1
,
kMAX_THREADS_PER_BLOCK
,
sizeof
(
int
)
*
(
batch_size
+
1
),
stream
,
trt_mha_padding_offset
,
sequence_token_count
,
batch_size
);
}
__global__
void
__launch_bounds__
(
kMAX_THREADS_PER_BLOCK
)
getTrtSequenceOffset
(
int
*
trt_mha_padding_offset
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
)
{
extern
__shared__
int
tmp_offset
[];
if
(
threadIdx
.
x
==
0
)
{
tmp_offset
[
0
]
=
0
;
// B for fused attention is 2 * batch_size
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
tmp_offset
[
i
*
2
+
1
]
=
tmp_offset
[
i
*
2
]
+
sequence_token_count
[
i
];
tmp_offset
[
i
*
2
+
2
]
=
sequence_length
*
(
i
+
1
);
}
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
2
*
batch_size
+
1
;
i
+=
blockDim
.
x
)
{
trt_mha_padding_offset
[
i
]
=
tmp_offset
[
i
];
}
}
// Get sequence offset for TensorRT fused attention when we keep the padding
void
LaunchTrtSequenceOffset
(
int
*
trt_mha_padding_offset
,
const
int
*
sequence_token_count
,
const
int
batch_size
,
const
int
sequence_length
,
hipStream_t
stream
)
{
hipLaunchKernelGGL
(
getTrtSequenceOffset
,
1
,
kMAX_THREADS_PER_BLOCK
,
sizeof
(
int
)
*
(
2
*
batch_size
+
1
),
stream
,
trt_mha_padding_offset
,
sequence_token_count
,
batch_size
,
sequence_length
);
}
}
// namespace rocm
}
// namespace contrib
}
// namespace onnxruntime
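As a sanity check of the layout documented in the getTokenOffset comment above, the standalone CPU sketch below reproduces token_offset, cumulated_token_count, and the total/max token counts for the documented example (three sequences with 1, 2 and 4 non-padding tokens, sequence_length 4). It is illustrative only and simply mirrors the kernel's two loops on the host.

#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 3, sequence_length = 4;
  const int sequence_token_count[3] = {1, 2, 4};  // non-padding tokens per sequence
  std::vector<int> token_offset;
  std::vector<int> cumulated(1, 0);
  int total = 0, max_tokens = 0;
  for (int i = 0; i < batch_size; ++i) {  // non-padding positions first
    const int count = sequence_token_count[i];
    if (count > max_tokens) max_tokens = count;
    cumulated.push_back(cumulated.back() + count);
    for (int j = 0; j < count; ++j) token_offset.push_back(i * sequence_length + j);
    total += count;
  }
  for (int i = 0; i < batch_size; ++i) {  // then the padding positions
    const int count = sequence_token_count[i];
    for (int j = 0; j < sequence_length - count; ++j) token_offset.push_back(i * sequence_length + count + j);
  }
  // Expected: token_offset = 0 4 5 8 9 10 11 1 2 3 6 7, total = 7, max = 4, cumulated = 0 1 3 7
  for (int v : token_offset) printf("%d ", v);
  printf("\ntotal=%d max=%d\n", total, max_tokens);
  return 0;
}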
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/bert_padding.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>

namespace onnxruntime {
namespace contrib {
namespace rocm {

// Build token indices for non-padding tokens and padding tokens.
void LaunchGetTokenOffset(int* token_count_buffer,
                          int* token_offset,
                          int* cumulated_token_count,
                          const int* sequence_token_count,
                          const int batch_size,
                          const int sequence_length,
                          hipStream_t stream);

// Remove paddings from input.
template <typename T>
void LaunchRemovePadding(T* output,
                         const T* input,
                         const int* token_offset,
                         const int token_count,
                         const int hidden_size,
                         hipStream_t stream);

// Rebuild paddings to restore output shape.
template <typename T>
void LaunchRestorePadding(T* output,
                          const T* input,
                          const int* token_offset,
                          const int token_count,
                          const int hidden_size,
                          const int batch_size,
                          const int sequence_length,
                          hipStream_t stream);

// Padding offset for TensorRT fused attention kernel
void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* mask_index,
                             const int batch_size,
                             hipStream_t stream);

void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
                             const int* mask_index,
                             const int batch_size,
                             const int sequence_length,
                             hipStream_t stream);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
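A rough CPU reference of what the LaunchRemovePadding / LaunchRestorePadding pair declared above computes, kept as an illustrative sketch only (the helper names here are hypothetical and token_offset is assumed to come from LaunchGetTokenOffset): removing padding gathers the rows listed first in token_offset into a compact [token_count, hidden_size] buffer, and restoring scatters them back while the padded rows end up zero-filled.

#include <vector>

std::vector<float> RemovePaddingRef(const std::vector<float>& input, const std::vector<int>& token_offset,
                                    int token_count, int hidden_size) {
  std::vector<float> out(static_cast<size_t>(token_count) * hidden_size);
  for (int t = 0; t < token_count; ++t)          // gather non-padding rows
    for (int i = 0; i < hidden_size; ++i)
      out[t * hidden_size + i] = input[token_offset[t] * hidden_size + i];
  return out;
}

std::vector<float> RestorePaddingRef(const std::vector<float>& compact, const std::vector<int>& token_offset,
                                     int token_count, int hidden_size, int batch_size, int sequence_length) {
  // Zero-initialize so padding positions stay zero, matching the GPU restorePadding behavior.
  std::vector<float> out(static_cast<size_t>(batch_size) * sequence_length * hidden_size, 0.0f);
  for (int t = 0; t < token_count; ++t)          // scatter rows back to their original positions
    for (int i = 0; i < hidden_size; ++i)
      out[token_offset[t] * hidden_size + i] = compact[t * hidden_size + i];
  return out;
}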
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/decoder_attention.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"

using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      DecoderAttention,                                           \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kRocmExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      DecoderAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

namespace {

Status CheckInputs(const TensorShape& query_shape,
                   const TensorShape& key_shape,
                   const TensorShape& q_weights_shape,
                   const TensorShape& kv_weights_shape,
                   const TensorShape& bias_shape,
                   const Tensor* key_padding_mask,
                   const Tensor* key_cache,
                   const Tensor* value_cache,
                   const bool static_kv,
                   const bool use_past,
                   const bool has_layer_state,
                   const bool has_key_padding_mask) {
  const auto& query_shape_dims = query_shape.GetDims();
  if (query_shape_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'query' is expected to have 3 dimensions, got ", query_shape_dims.size());
  }
  int sequence_length = static_cast<int>(query_shape_dims[0]);
  int batch_size = static_cast<int>(query_shape_dims[1]);
  int hidden_size = static_cast<int>(query_shape_dims[2]);

  const auto& key_shape_dims = key_shape.GetDims();
  if (key_shape_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'key' is expected to have 3 dimensions, got ", key_shape_dims.size());
  }
  int kv_sequence_length = static_cast<int>(key_shape_dims[0]);
  if (query_shape_dims[1] != key_shape_dims[1]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same batch size");
  }
  if (query_shape_dims[2] != key_shape_dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "query and key shall have the same hidden size");
  }

  const auto& q_weights_dims = q_weights_shape.GetDims();
  if (q_weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'q_weights' is expected to have 2 dimensions, got ", q_weights_dims.size());
  }
  const auto& kv_weights_dims = kv_weights_shape.GetDims();
  if (kv_weights_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'kv_weights' is expected to have 2 dimensions, got ", kv_weights_dims.size());
  }
  if (q_weights_dims[0] != hidden_size || q_weights_dims[1] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "q_weights shall have shape (hidden size, hidden size)");
  }
  if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast<int64_t>(hidden_size)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)");
  }

  const auto& bias_dims = bias_shape.GetDims();
  if (bias_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 'bias' is expected to have 1 dimension, got ", bias_dims.size());
  }
  if (bias_dims[0] != 3 * static_cast<int64_t>(hidden_size)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "bias shall have shape (3 * hidden size)");
  }

  int key_length = kv_sequence_length;
  if (key_padding_mask != nullptr && has_key_padding_mask == true) {
    const auto& kp_mask_dims = key_padding_mask->Shape().GetDims();
    if (kp_mask_dims.size() != 2) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'key_padding_mask' is expected to have 2 dimension, got ", kp_mask_dims.size());
    }
    if (kp_mask_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "key_padding_mask shall have same batch size with query");
    }

    if (!has_layer_state || !use_past) {
      if (!static_kv) {
        key_length = sequence_length;
      }
    } else {
      if (!static_kv) {
        key_length = sequence_length + kv_sequence_length;
      }
    }

    if (kp_mask_dims[1] != key_length) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "key_padding_mask shall have same sequence length as generated key");
    }
  }

  if (key_cache != nullptr && value_cache != nullptr && has_layer_state && use_past) {
    const auto& key_cache_dims = key_cache->Shape().GetDims();
    if (key_cache_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'key_cache' is expected to have 4 dimension, got ", key_cache_dims.size());
    }
    const auto& value_cache_dims = value_cache->Shape().GetDims();
    if (value_cache_dims.size() != 4) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size());
    }
    if (key_cache_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have same batch size as query");
    }
    if (value_cache_dims[0] != batch_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have same batch size as query");
    }
    if (key_cache_dims[1] * key_cache_dims[3] != hidden_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_cache shall have correct hidden size");
    }
    if (value_cache_dims[1] * value_cache_dims[3] != hidden_size) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "value_cache shall have correct hidden size");
    }
  }

  return Status::OK();
}

}  // anonymous namespace

template <typename T>
DecoderAttention<T>::DecoderAttention(const OpKernelInfo& info) : RocmKernel(info) {
  int64_t num_heads = 0;
  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
  num_heads_ = static_cast<int>(num_heads);
}

template <typename T>
Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* query(context->Input<Tensor>(0));
  const Tensor* key(context->Input<Tensor>(1));
  const Tensor* q_weights(context->Input<Tensor>(2));
  const Tensor* kv_weights(context->Input<Tensor>(3));
  const Tensor* bias(context->Input<Tensor>(4));
  const Tensor* key_padding_mask(context->Input<Tensor>(5));
  const Tensor* key_cache(context->Input<Tensor>(6));
  const Tensor* value_cache(context->Input<Tensor>(7));
  const Tensor* static_kv(context->Input<Tensor>(8));
  const Tensor* use_past(context->Input<Tensor>(9));
  const Tensor* has_layer_state(context->Input<Tensor>(10));
  const Tensor* has_key_padding_mask(context->Input<Tensor>(11));

  hipStream_t stream = Stream();

  // Copy static_kv, use_past, has_layer_state and has_key_padding_mask to CPU
  auto pinned_buffer = AllocateBufferOnCPUPinned<void>(4 * sizeof(bool));
  bool* kernel_state_pinned = reinterpret_cast<bool*>(pinned_buffer.get());
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned, static_kv->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 1, use_past->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 2, has_layer_state->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(kernel_state_pinned + 3, has_key_padding_mask->Data<bool>(), sizeof(bool),
                                     hipMemcpyDeviceToHost, stream));

  // Create an event to make sure the async copy is finished before reading the data.
  AutoDestoryCudaEvent new_event;
  hipEvent_t& isCopyDone = new_event.Get();

  HIP_RETURN_IF_ERROR(hipEventCreate(&isCopyDone));
  HIP_RETURN_IF_ERROR(hipEventRecord(isCopyDone, stream));

  auto& device_prop = GetDeviceProp();

  // query shape (batch_size, sequence_length, input_hidden_size)
  const auto& query_shape = query->Shape();
  int sequence_length = static_cast<int>(query_shape[0]);
  int batch_size = static_cast<int>(query_shape[1]);
  int hidden_size = static_cast<int>(query_shape[2]);

  const auto& key_shape = key->Shape();
  int key_sequence_length = static_cast<int>(key_shape[0]);

  int head_size = hidden_size / num_heads_;

  // k, v sequence length after gemm
  int kv_sequence_length = 0;

  // Generate q, k, v w/o cache
  // query input: (S, B, h1)
  // key input: (S', B, h1)
  // weight: (h1, h2)
  // h = N*H
  rocblas_handle rocblas = RocblasHandle();
  ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));

  constexpr size_t element_size = sizeof(T);

  typedef typename ToHipType<T>::MappedType HipT;
  HipT one = ToHipType<T>::FromFloat(1.0f);
  HipT zero = ToHipType<T>::FromFloat(0.0f);

  int m = 0, n = 0, k = 0;
  IAllocatorUniquePtr<T> gemm_query_buffer_p(nullptr);
  IAllocatorUniquePtr<T> gemm_kv_buffer_p(nullptr);

  HIP_RETURN_IF_ERROR(hipEventSynchronize(isCopyDone));

  bool static_kv_ = *kernel_state_pinned;
  bool use_past_ = *(kernel_state_pinned + 1);
  bool has_layer_state_ = *(kernel_state_pinned + 2);
  bool has_key_padding_mask_ = *(kernel_state_pinned + 3);

  ORT_RETURN_IF_ERROR(
      CheckInputs(query->Shape(), key->Shape(), q_weights->Shape(), kv_weights->Shape(), bias->Shape(),
                  key_padding_mask, key_cache, value_cache,
                  static_kv_, use_past_, has_layer_state_, has_key_padding_mask_));

  // calculate q
  gemm_query_buffer_p = GetScratchBuffer<T>(batch_size * sequence_length * hidden_size * element_size);

  m = sequence_length * batch_size;
  n = hidden_size;
  k = hidden_size;

  // TODO(tianleiwu): fuse bias and transpose
  // broadcast bias for query: (h2, S*B)
  ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
      rocblas, rocblas_operation_none, rocblas_operation_none,
      n, m, 1, &one,
      reinterpret_cast<const HipT*>(bias->Data<T>()), n,
      GetConstOnes<HipT>(m), 1,
      &zero, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
  // matmul: (h2, h1)*(h1, S*B)
  ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
      rocblas, rocblas_operation_none, rocblas_operation_none,
      n, m, k, &one,
      reinterpret_cast<const HipT*>(q_weights->Data<T>()), n,
      reinterpret_cast<const HipT*>(query->Data<T>()), k,
      &one, reinterpret_cast<HipT*>(gemm_query_buffer_p.get()), n, device_prop));
  // gemm_query_buffer in col-base: (h2, S*B)

  // calculate k, v
  n = 2 * hidden_size;
  k = hidden_size;
  if (!has_layer_state_ || !use_past_) {
    if (!static_kv_) {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
      m = sequence_length * batch_size;
      n = 2 * hidden_size;
      k = hidden_size;
      kv_sequence_length = sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(query->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    } else {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * key_sequence_length * hidden_size * element_size);
      m = key_sequence_length * batch_size;
      n = 2 * hidden_size;
      k = hidden_size;
      kv_sequence_length = key_sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(key->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    }
  } else {
    ORT_ENFORCE(nullptr != key_cache && nullptr != value_cache);
    // (B, N, S, H)
    const auto& cache_shape = key_cache->Shape();
    // key and value cache have identical shape
    int cache_sequence_length = static_cast<int>(cache_shape[2]);
    if (!static_kv_) {
      gemm_kv_buffer_p = GetScratchBuffer<T>(batch_size * 2 * sequence_length * hidden_size * element_size);
      m = sequence_length * batch_size;
      kv_sequence_length = cache_sequence_length + sequence_length;
      // broadcast bias for key and value: (2*h2, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, 1, &one,
          reinterpret_cast<const HipT*>(bias->Data<T>() + hidden_size), n,
          GetConstOnes<HipT>(m), 1,
          &zero, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // matmul: (2*h2, h1)*(h1, T_S*B)
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          n, m, k, &one,
          reinterpret_cast<const HipT*>(kv_weights->Data<T>()), n,
          reinterpret_cast<const HipT*>(query->Data<T>()), k,
          &one, reinterpret_cast<HipT*>(gemm_kv_buffer_p.get()), n, device_prop));
      // gemm_kv_buffer in col-base: (2*h2, T_S*B)
    } else {
      kv_sequence_length = cache_sequence_length;
    }
  }

  size_t bytes = element_size * batch_size * (sequence_length + 2 * kv_sequence_length) * hidden_size;
  auto qkv_buffer_p = GetScratchBuffer<void>(bytes);

  bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (2 * head_size + kv_sequence_length);
  auto workspace_p = GetScratchBuffer<void>(bytes);

  Tensor* output(context->Output(0, query_shape));
  TensorShape new_cache_shape({batch_size, num_heads_, kv_sequence_length, head_size});
  Tensor* new_key_cache(context->Output(1, new_cache_shape));
  Tensor* new_value_cache(context->Output(2, new_cache_shape));

  return LaunchDecoderAttentionKernel(
      device_prop,
      stream,
      rocblas,
      element_size,
      batch_size,
      sequence_length,
      kv_sequence_length,
      num_heads_,
      head_size,
      static_kv_,
      use_past_,
      has_layer_state_,
      has_key_padding_mask_,
      nullptr == gemm_query_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_query_buffer_p.get()),
      nullptr == gemm_kv_buffer_p ? nullptr : reinterpret_cast<const HipT*>(gemm_kv_buffer_p.get()),
      nullptr == key_padding_mask ? nullptr : key_padding_mask->Data<bool>(),
      nullptr == key_cache ? nullptr : key_cache->Data<T>(),
      nullptr == value_cache ? nullptr : value_cache->Data<T>(),
      qkv_buffer_p.get(),
      workspace_p.get(),
      output->MutableData<T>(),
      nullptr == new_key_cache ? nullptr : new_key_cache->MutableData<T>(),
      nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>());
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
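The GEMM sequence in DecoderAttention<T>::ComputeInternal broadcasts the bias with a rank-1 GEMM (k = 1, a ones vector as the second operand, beta = 0) and then accumulates the weight projection with a second GEMM whose beta is one. The standalone sketch below uses tiny assumed sizes and plain loops instead of rocBLAS, purely to show why that pattern yields W*X plus a per-column bias in a column-major result.

#include <cstdio>
#include <vector>

int main() {
  const int n = 2, m = 3, k = 2;                           // tiny sizes (assumed, for illustration)
  const float bias[n] = {10.f, 20.f};
  const float ones[m] = {1.f, 1.f, 1.f};
  const float W[n * k] = {1.f, 0.f, 0.f, 1.f};             // identity weight, column-major
  const float X[k * m] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};   // column-major input, one column per token
  std::vector<float> C(n * m, 0.f);                        // column-major result, n x m
  for (int col = 0; col < m; ++col)                        // step 1: rank-1 GEMM broadcasts bias into every column
    for (int row = 0; row < n; ++row) C[col * n + row] = bias[row] * ones[col];
  for (int col = 0; col < m; ++col)                        // step 2: C += W * X (beta = 1 keeps the bias)
    for (int row = 0; row < n; ++row)
      for (int p = 0; p < k; ++p) C[col * n + row] += W[p * n + row] * X[col * k + p];
  for (int col = 0; col < m; ++col)
    printf("(%g, %g) ", C[col * n + 0], C[col * n + 1]);   // each column equals the X column plus the bias
  printf("\n");
  return 0;
}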
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/decoder_attention.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

template <typename T>
class DecoderAttention final : public RocmKernel {
 public:
  DecoderAttention(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int num_heads_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include <hip/hip_fp16.h>
#include <rocblas.h>
#include <hipcub/hipcub.hpp>
using
namespace
onnxruntime
::
rocm
;
using
namespace
hipcub
;
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
template
<
typename
T
>
__device__
inline
T
Rsqrt
(
const
T
&
x
);
template
<
>
__device__
inline
float
Rsqrt
(
const
float
&
x
)
{
return
rsqrtf
(
x
);
}
template
<
>
__device__
inline
half
Rsqrt
(
const
half
&
x
)
{
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return
hrsqrt
(
x
);
#else
return
half
(
rsqrtf
(
float
(
x
)));
#endif
}
__device__
inline
half2
AddHalf2
(
const
half2
a
,
const
half2
b
)
{
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return
__hadd2
(
a
,
b
);
#else
return
__halves2half2
(
__hadd
(
a
.
x
,
b
.
x
),
__hadd
(
a
.
y
,
b
.
y
));
#endif
}
struct
KeyValuePairSum
{
__device__
inline
hipcub
::
KeyValuePair
<
float
,
float
>
operator
()(
const
hipcub
::
KeyValuePair
<
float
,
float
>&
a
,
const
hipcub
::
KeyValuePair
<
float
,
float
>&
b
)
{
return
hipcub
::
KeyValuePair
<
float
,
float
>
(
a
.
key
+
b
.
key
,
a
.
value
+
b
.
value
);
}
__device__
inline
hipcub
::
KeyValuePair
<
half
,
half
>
operator
()(
const
hipcub
::
KeyValuePair
<
half
,
half
>&
a
,
const
hipcub
::
KeyValuePair
<
half
,
half
>&
b
)
{
const
half2
a2
=
__halves2half2
(
a
.
key
,
a
.
value
);
const
half2
b2
=
__halves2half2
(
b
.
key
,
b
.
value
);
const
half2
res
=
AddHalf2
(
a2
,
b2
);
return
hipcub
::
KeyValuePair
<
half
,
half
>
(
__low2half
(
res
),
__high2half
(
res
));
}
__device__
inline
hipcub
::
KeyValuePair
<
half2
,
half2
>
operator
()(
const
hipcub
::
KeyValuePair
<
half2
,
half2
>&
a
,
const
hipcub
::
KeyValuePair
<
half2
,
half2
>&
b
)
{
return
hipcub
::
KeyValuePair
<
half2
,
half2
>
(
AddHalf2
(
a
.
key
,
b
.
key
),
AddHalf2
(
a
.
value
,
b
.
value
));
}
};
template
<
typename
T
,
int
TPB
>
__device__
inline
void
LayerNorm
(
const
hipcub
::
KeyValuePair
<
T
,
T
>&
thread_data
,
const
int
ld
,
const
int
offset
,
const
T
*
beta
,
const
T
*
gamma
,
const
T
epsilon
,
T
*
output
)
{
// Assuming thread_data is already divided by ld
using
BlockReduce
=
hipcub
::
BlockReduce
<
hipcub
::
KeyValuePair
<
T
,
T
>
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
mu
;
// mean
__shared__
T
rsigma
;
// 1 / std.dev.
KeyValuePairSum
pair_sum
;
const
auto
sum_kv
=
BlockReduce
(
temp_storage
).
Reduce
(
thread_data
,
pair_sum
);
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
Rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
epsilon
);
}
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
ld
;
i
+=
TPB
)
{
const
int
idx
=
offset
+
i
;
const
T
val
=
output
[
idx
];
const
T
g
(
gamma
[
i
]);
const
T
b
=
(
nullptr
==
beta
)
?
(
T
)
0
:
beta
[
i
];
output
[
idx
]
=
g
*
(
val
-
mu
)
*
rsigma
+
b
;
}
}
template
<
typename
T
,
int
TPB
,
int
ILP
>
__device__
inline
void
LayerNormSmall
(
const
T
*
input_v
,
const
hipcub
::
KeyValuePair
<
T
,
T
>&
thread_data
,
const
int
ld
,
const
int
idx
,
const
T
*
beta
,
const
T
*
gamma
,
const
T
epsilon
,
T
*
output
)
{
// Assuming thread_data is already divided by ld
// Small settings: the block covers the leading dimension TPB >= ld. The input
// value is available in a register
using
VecT
=
aligned_vector
<
T
,
ILP
>
;
using
BlockReduce
=
hipcub
::
BlockReduce
<
hipcub
::
KeyValuePair
<
T
,
T
>
,
TPB
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
mu
;
// mean
__shared__
T
rsigma
;
// 1 / std.dev.
T
beta_v
[
ILP
],
gamma_v
[
ILP
],
output_v
[
ILP
];
if
(
beta
!=
nullptr
)
{
VecT
*
beta_val
=
reinterpret_cast
<
VecT
*>
(
&
beta_v
);
*
beta_val
=
*
reinterpret_cast
<
const
VecT
*>
(
&
beta
[
threadIdx
.
x
*
ILP
]);
}
VecT
*
gamma_val
=
reinterpret_cast
<
VecT
*>
(
&
gamma_v
);
*
gamma_val
=
*
reinterpret_cast
<
const
VecT
*>
(
&
gamma
[
threadIdx
.
x
*
ILP
]);
VecT
*
output_val
=
reinterpret_cast
<
VecT
*>
(
&
output_v
);
KeyValuePairSum
pair_sum
;
const
hipcub
::
KeyValuePair
<
T
,
T
>
sum_kv
=
BlockReduce
(
temp_storage
).
Reduce
(
thread_data
,
pair_sum
);
if
(
threadIdx
.
x
==
0
)
{
mu
=
sum_kv
.
key
;
rsigma
=
Rsqrt
(
sum_kv
.
value
-
mu
*
mu
+
epsilon
);
}
__syncthreads
();
if
(
ILP
*
threadIdx
.
x
<
ld
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
ILP
;
i
++
)
{
output_v
[
i
]
=
(
beta
!=
nullptr
)
?
gamma_v
[
i
]
*
(
input_v
[
i
]
-
mu
)
*
rsigma
+
beta_v
[
i
]
:
gamma_v
[
i
]
*
(
input_v
[
i
]
-
mu
)
*
rsigma
;
}
*
(
reinterpret_cast
<
VecT
*>
(
&
output
[
idx
]))
=
*
output_val
;
}
}
}
// namespace rocm
}
// namespace contrib
}
// namespace onnxruntime
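The LayerNorm helpers above reduce a per-thread (sum, sum-of-squares) pair that has already been divided by ld, take mu from the first component, compute rsigma = 1/sqrt(E[x^2] - mu^2 + epsilon), and then apply gamma and beta. The standalone CPU sketch below repeats the same arithmetic on assumed example values, purely as an illustration of the formula.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float epsilon = 1e-5f;                     // assumed example epsilon
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};     // one row of length ld
  std::vector<float> gamma(x.size(), 1.f), beta(x.size(), 0.f);
  float sum = 0.f, sum_sq = 0.f;
  for (float v : x) {                              // the (key, value) pair reduction, pre-divided by ld
    sum += v / x.size();
    sum_sq += v * v / x.size();
  }
  const float mu = sum;                            // mean
  const float rsigma = 1.f / std::sqrt(sum_sq - mu * mu + epsilon);  // 1 / std.dev.
  for (size_t i = 0; i < x.size(); ++i)
    printf("%f ", gamma[i] * (x[i] - mu) * rsigma + beta[i]);
  printf("\n");
  return 0;
}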
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/rocm/bert/longformer_global_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
#include "contrib_ops/rocm/bert/transformer_rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention.h"

using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace rocm {

#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      LongformerAttention,                                        \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kRocmExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      LongformerAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
LongformerAttention<T>::LongformerAttention(const OpKernelInfo& info)
    : RocmKernel(info), LongformerAttentionBase(info) {
  use_compact_memory_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseCompactMemory, true);
  use_half4_ = ParseEnvironmentVariableWithDefault<bool>(longformer::kUseHalf4, true);
}

template <typename T>
Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* weights = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);
  const Tensor* attention_mask = context->Input<Tensor>(3);
  const Tensor* global_weights = context->Input<Tensor>(4);
  const Tensor* global_bias = context->Input<Tensor>(5);
  const Tensor* global_attention_mask = context->Input<Tensor>(6);
  ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), attention_mask->Shape(),
                                  global_weights->Shape(), global_bias->Shape(), global_attention_mask->Shape()));

  // Input shapes:
  //   input                 : (batch_size, sequence_length, hidden_size)
  //   weights               : (hidden_size, 3 * hidden_size) -- format 1
  //                           (3, hidden_size, hidden_size)  -- format 0
  //   bias                  : (3 * hidden_size)              -- format 1 (bias for Q, K, V)
  //                           (5 * hidden_size)              -- format 0 (bias for Q, K, V, Global_K, Global_V)
  //   attention_mask        : (batch_size, sequence_length)
  //   global_weights        : (hidden_size, 3 * hidden_size) -- format 1
  //                           (3, hidden_size, hidden_size)  -- format 0
  //   global_bias           : (3 * hidden_size)              -- format 1 (bias for Global_Q, Global_K, Global_V)
  //                           (1 * hidden_size)              -- format 0 (bias for Global_Q)
  //   global_attention_mask : (batch_size, sequence_length)
  // Output shapes:
  //   output                : (batch_size, sequence_length, hidden_size)

  const auto& shape = input->Shape();
  int batch_size = static_cast<int>(shape[0]);
  int sequence_length = static_cast<int>(shape[1]);
  int hidden_size = static_cast<int>(shape[2]);
  int head_size = hidden_size / num_heads_;

  Tensor* output = context->Output(0, shape);

  rocblas_handle rocblas = RocblasHandle();
  hipStream_t stream = Stream();
  ROCBLAS_RETURN_IF_ERROR(rocblas_set_stream(rocblas, stream));

  constexpr size_t element_size = sizeof(T);

  // TODO(tianleiwu): only calculate global index once per model instead of once per LongformerAttention node.
  // Build Global Index
  auto global_index_buffer = GetScratchBuffer<int>(static_cast<size_t>(batch_size) * sequence_length);
  auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);

  size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length);
  auto global_scratch_buffer = GetScratchBuffer<void>(global_scratch_bytes);

  auto& device_prop = GetDeviceProp();
  ORT_RETURN_IF_ERROR(BuildGlobalIndex(
      device_prop,
      stream,
      global_attention_mask->Data<int>(),
      batch_size,
      sequence_length,
      global_index_buffer.get(),
      batch_global_num_buffer.get(),
      global_scratch_buffer.get(),
      global_scratch_bytes));

  // Copy batch_global_num to CPU
  size_t pinned_buffer_bytes = GetPinnedBufferSize(batch_size);
  auto pinned_buffer = AllocateBufferOnCPUPinned<void>(pinned_buffer_bytes);
  int* batch_global_num_pinned = reinterpret_cast<int*>(pinned_buffer.get());
  HIP_RETURN_IF_ERROR(hipMemcpyAsync(batch_global_num_pinned,
                                     batch_global_num_buffer.get(),
                                     batch_size * sizeof(int),
                                     hipMemcpyDeviceToHost,
                                     stream));

  // Create an event to make sure the async copy is finished before reading the data.
  AutoDestoryCudaEvent new_event;
  hipEvent_t& is_copy_done = new_event.Get();

  HIP_RETURN_IF_ERROR(hipEventCreateWithFlags(&is_copy_done, hipEventDisableTiming));
  HIP_RETURN_IF_ERROR(hipEventRecord(is_copy_done, stream));

  size_t qkv_size = batch_size * sequence_length * 3 * hidden_size * element_size;
  // Buffer for GEMM outputs of q, k, v, global_q, global_k and global_v
  // TODO(tianleiwu): compact global_q only need batch_size * window * hidden_size * element_size buffer size.
  auto gemm_buffer = GetScratchBuffer<void>(qkv_size + qkv_size);

  bool use_merged_qkv_weights = (weights->Shape().NumDimensions() == 2);

  int m = batch_size * sequence_length;
  int n = use_merged_qkv_weights ? 3 * hidden_size : hidden_size;
  int k = hidden_size;

  typedef typename ToHipType<T>::MappedType HipT;
  const HipT* input_data = reinterpret_cast<const HipT*>(input->Data<T>());
  const HipT* weights_data = reinterpret_cast<const HipT*>(weights->Data<T>());
  const HipT* global_weights_data = reinterpret_cast<const HipT*>(global_weights->Data<T>());

  float one = 1.0f;
  float zero = 0.0f;

  if (use_merged_qkv_weights) {
    // Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 0 x B.
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        weights_data, n,
        input_data, k,
        &zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n, device_prop));
  } else {
    // q
    const HipT* q_weight = weights_data;
    HipT* q_data = reinterpret_cast<HipT*>(gemm_buffer.get());
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        q_weight, n,
        input_data, k,
        &zero, q_data, n, device_prop));
    // k
    const HipT* k_weight = q_weight + hidden_size * hidden_size;
    HipT* k_data = q_data + batch_size * sequence_length * hidden_size;
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        k_weight, n,
        input_data, k,
        &zero, k_data, n, device_prop));
    // v
    const HipT* v_weight = k_weight + hidden_size * hidden_size;
    HipT* v_data = k_data + batch_size * sequence_length * hidden_size;
    ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
        rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
        v_weight, n,
        input_data, k,
        &zero, v_data, n, device_prop));
  }

  // Wait for async copy of batch_global_num
  HIP_RETURN_IF_ERROR(hipEventSynchronize(is_copy_done));

  // Find the maximum number of global tokens in all batches
  int max_num_global = 0;
  for (int i = 0; i < batch_size; ++i) {
    if (max_num_global < batch_global_num_pinned[i]) {
      max_num_global = batch_global_num_pinned[i];
    }
  }

  // Do not use compact memory kernel in the following situations:
  // (1) global tokens > window size: the compact memory kernel cannot be used due to its assumptions.
  // (2) sequence_length == 2 * attention_window: the compact memory kernel has a parity issue.
  // (3) user sets environment variable ORT_LONGFORMER_COMPACT_MEMORY=0
  bool disable_compact_memory = (max_num_global > window_ || sequence_length == 2 * window_ || !use_compact_memory_);

  // Fully-connected projection for global tokens.
  // Note that Q only needs to handle global query tokens if we split the GEMM into global Q/K/V separately.
  // When there is no global token, the global GEMM need not run.
  HipT* global_gemm_buffer = nullptr;

  if (max_num_global > 0) {
    global_gemm_buffer = reinterpret_cast<HipT*>(reinterpret_cast<char*>(gemm_buffer.get()) + qkv_size);

    if (use_merged_qkv_weights) {
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          reinterpret_cast<const HipT*>(global_weights->Data<T>()), n,
          input_data, k,
          &zero, global_gemm_buffer, n, device_prop));
    } else {
      // global q
      const HipT* global_q_weight = global_weights_data;
      HipT* global_q = global_gemm_buffer + 2 * batch_size * sequence_length * hidden_size;
      if (disable_compact_memory) {
        ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
            rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
            global_q_weight, n,
            input_data, k,
            &zero, global_q, n, device_prop));
      } else {
        ROCBLAS_RETURN_IF_ERROR(rocblasGemmStridedBatchedHelper(
            rocblas,
            rocblas_operation_none,
            rocblas_operation_none,
            hidden_size,                    // m
            max_num_global,                 // n
            hidden_size,                    // k
            &one,                           // alpha
            global_q_weight,                // A
            hidden_size,                    // lda
            0,                              // strideA
            input_data,                     // B
            hidden_size,                    // ldb
            sequence_length * hidden_size,  // strideB
            &zero,                          // beta
            global_q,                       // C
            hidden_size,                    // ldc
            max_num_global * hidden_size,   // strideC
            batch_size,                     // batch count
            device_prop));
      }
      // global k
      const HipT* global_k_weight = global_weights_data + hidden_size * hidden_size;
      HipT* global_k = global_gemm_buffer;
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          global_k_weight, n,
          input_data, k,
          &zero, global_k, n, device_prop));
      // global v
      const HipT* global_v_weight = global_k_weight + hidden_size * hidden_size;
      HipT* global_v = global_gemm_buffer + batch_size * sequence_length * hidden_size;
      ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
          rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
          global_v_weight, n,
          input_data, k,
          &zero, global_v, n, device_prop));
    }
  }

  size_t workSpaceSize = GetLongformerAttentionWorkspaceSize(element_size,
                                                             batch_size,
                                                             num_heads_,
                                                             head_size,
                                                             sequence_length,
                                                             max_num_global,
                                                             window_,
                                                             disable_compact_memory);
  auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
  ORT_RETURN_IF_ERROR(LaunchLongformerAttentionKernel(
      device_prop,
      rocblas,
      stream,
      reinterpret_cast<const HipT*>(gemm_buffer.get()),
      reinterpret_cast<const HipT*>(bias->Data<T>()),
      reinterpret_cast<const HipT*>(attention_mask->Data<T>()),
      reinterpret_cast<const HipT*>(global_gemm_buffer),
      reinterpret_cast<const HipT*>(global_bias->Data<T>()),
      global_attention_mask->Data<int>(),
      global_index_buffer.get(),
      batch_global_num_buffer.get(),
      pinned_buffer.get(),
      workspace_buffer.get(),
      output->MutableData<T>(),
      batch_size,
      sequence_length,
      num_heads_,
      head_size,
      window_,
      max_num_global,
      element_size,
      disable_compact_memory,
      use_merged_qkv_weights,
      use_half4_));

  // Defer release of pinned memory since hipStreamSynchronize is not used here and the kernel still needs to access the buffer.
  this->AddDeferredReleaseCPUPtr(pinned_buffer.release());
  return Status::OK();
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
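The compact-memory decision above depends only on the maximum number of global tokens across the batch, the one-sided attention window, and the environment override. The standalone sketch below replays that logic with assumed values (the batch_global_num contents are made up purely for illustration).

#include <algorithm>
#include <cstdio>

int main() {
  const int batch_size = 3, window = 256, sequence_length = 4096;  // assumed example values
  const bool use_compact_memory = true;                            // default unless ORT_LONGFORMER_COMPACT_MEMORY=0
  const int batch_global_num_pinned[3] = {8, 0, 300};              // global tokens per batch (made up)
  int max_num_global = 0;
  for (int i = 0; i < batch_size; ++i)
    max_num_global = std::max(max_num_global, batch_global_num_pinned[i]);
  const bool disable_compact_memory =
      (max_num_global > window || sequence_length == 2 * window || !use_compact_memory);
  printf("max_num_global=%d disable_compact_memory=%d\n", max_num_global, disable_compact_memory);
  return 0;
}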
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "contrib_ops/cpu/bert/longformer_attention_base.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

using namespace onnxruntime::rocm;

template <typename T>
class LongformerAttention final : public RocmKernel, public LongformerAttentionBase {
 public:
  LongformerAttention(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  bool use_compact_memory_;
  bool use_half4_;
};

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.cu
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Limitations of current Longformer Attention ROCM Kernels:
// (1) Does not support global tokens in the middle. All global tokens shall be in the beginning of sequence.
// (2) Maximum number of global tokens <= one-sided attention window
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <limits>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/add_bias_transpose.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/longformer_attention_softmax.h"
#include "contrib_ops/rocm/bert/longformer_attention_impl.h"
using
namespace
onnxruntime
::
rocm
;
using
namespace
hipcub
;
#define CHECK(expr) ROCBLAS_RETURN_IF_ERROR(expr)
#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr)
namespace
onnxruntime
{
namespace
contrib
{
namespace
rocm
{
// Denote: batch size (B), sequence length (S), number of heads (N), dimension per head (H), maximum global tokens (G)
//
// Workspace layout (default data type T is float or half):
// [SoftmaxSpace] [Q:BxNxSxH] [K:BxNxSxH] [V:BxNxSxH] [Global_Q:BxNxSxH] [Global_K:BxNxSxH] [Global_V:BxNxSxH]
// where Global_Q, Global_K and Global_V are optional. They are not allocated when there is no global token.
//
// SoftmaxSpace layout is the following when compact memory is enabled:
// [scratch1: (5S-3W)*W*N*B] [scratch2: size_t 15]
// Scratch1 has 5 buffers for local and global attention calculation.
// Scratch2 has 5 input/output pointers, 5 buffer sizes and 5 strides related to scratch1.
//
// SoftmaxSpace layout is the following When compact memory is disabled:
// [scratch1: BxNxSxS] [scratch2: BxNxSxS]
static
size_t
Align
(
size_t
a
)
{
const
size_t
alignment
=
128
;
// Align on a 16-byte boundary to avoid "misaligned address" error.
return
CeilDiv
(
a
,
alignment
)
*
alignment
;
}
size_t
GetScratch1Size
(
size_t
element_size
,
size_t
batch_size
,
size_t
num_heads
,
size_t
sequence_length
,
size_t
window
)
{
size_t
bytes
=
(
5
*
sequence_length
-
3
*
window
)
*
window
*
num_heads
*
batch_size
*
element_size
;
return
Align
(
bytes
);
}
constexpr
size_t
GetScratch2Size
()
{
return
5
*
sizeof
(
void
*
)
+
10
*
sizeof
(
size_t
);
}
size_t
GetLongformerSoftmaxWorkspaceSize
(
size_t
element_size
,
size_t
batch_size
,
size_t
num_heads
,
size_t
sequence_length
,
size_t
window
,
bool
disable_compact_memory
)
{
if
(
!
disable_compact_memory
)
{
size_t
scratch1_size
=
GetScratch1Size
(
element_size
,
batch_size
,
num_heads
,
sequence_length
,
window
);
size_t
scratch2_size
=
GetScratch2Size
();
return
Align
(
scratch1_size
+
scratch2_size
);
}
else
{
return
2
*
GetAttentionScratchSize
(
element_size
,
batch_size
,
num_heads
,
sequence_length
,
sequence_length
);
}
}
size_t
GetLongformerAttentionWorkspaceSize
(
size_t
element_size
,
size_t
batch_size
,
size_t
num_heads
,
size_t
head_size
,
size_t
sequence_length
,
size_t
max_num_global
,
size_t
window
,
bool
disable_compact_memory
)
{
size_t
softmax_size
=
GetLongformerSoftmaxWorkspaceSize
(
element_size
,
batch_size
,
num_heads
,
sequence_length
,
window
,
disable_compact_memory
);
size_t
qkv_size
=
static_cast
<
size_t
>
(
3
)
*
batch_size
*
sequence_length
*
num_heads
*
head_size
*
element_size
;
size_t
global_qkv_size
=
max_num_global
>
0
?
qkv_size
:
0
;
return
softmax_size
+
qkv_size
+
global_qkv_size
;
}
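// A standalone worked example of the compact-memory workspace formulas above (illustrative only;
// the dimensions B, N, H, S, W, G below are assumed, not taken from any model):
//   scratch1 = align((5S - 3W) * W * N * B * element_size)
//   scratch2 = 5 * sizeof(void*) + 10 * sizeof(size_t)
//   qkv      = 3 * B * S * N * H * element_size
#include <cstdio>

int main() {
  const size_t element_size = 2;  // half
  const size_t B = 1, N = 12, H = 64, S = 4096, W = 512, G = 16;  // example dims (assumed)
  auto align = [](size_t a) { return (a + 127) / 128 * 128; };    // same rounding as Align() above
  const size_t scratch1 = align((5 * S - 3 * W) * W * N * B * element_size);
  const size_t scratch2 = 5 * sizeof(void*) + 10 * sizeof(size_t);
  const size_t softmax = align(scratch1 + scratch2);
  const size_t qkv = 3 * B * S * N * H * element_size;
  const size_t global_qkv = (G > 0) ? qkv : 0;
  printf("softmax=%zu qkv=%zu total=%zu bytes\n", softmax, qkv, softmax + qkv + global_qkv);
  return 0;
}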
// Size of buffer of pinned memory in CPU. The buffer is used to copy memory between CPU and GPU.
// The buffer includes two parts: [global_count (copy of batch_global_num): int Bx1] [copy of scratch2]
size_t
GetPinnedBufferSize
(
size_t
batch_size
)
{
return
sizeof
(
int
)
*
batch_size
+
GetScratch2Size
();
}
// Softmax kernel for compact format
template <typename T, int blockSize>
__launch_bounds__(blockSize)
__global__ void LongformerSoftmaxKernel(const int* global_attention,
                                        const int* global_index,
                                        const int* batch_global_num,
                                        void* buffer_pointers,
                                        const T* attention_mask,
                                        float scaler,
                                        int sequence_length,
                                        int num_heads,
                                        int window) {
  typedef hipcub::BlockReduce<float, blockSize> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_reduce_temp;

  int tid = threadIdx.x;
  const int batch_index = blockIdx.x / (sequence_length * num_heads);
  const int row_index = blockIdx.x % sequence_length;
  const int head_index = (blockIdx.x / sequence_length) % num_heads;

  // Adjust the pointers for the batch
  const T* mask_block = attention_mask + sequence_length * batch_index;
  const int* global_index_block = global_index + sequence_length * batch_index;
  const int global_num = batch_global_num[batch_index];

  size_t* p_inputs = reinterpret_cast<size_t*>(buffer_pointers);
  size_t* p_outputs = reinterpret_cast<size_t*>(buffer_pointers);
  size_t* input_sizes = reinterpret_cast<size_t*>(buffer_pointers) + 5;
  size_t* input_strides = reinterpret_cast<size_t*>(buffer_pointers) + 10;

  const T* inputs[5];
  T* outputs[5];
  for (int i = 0; i < 5; ++i) {
    inputs[i] = reinterpret_cast<T*>(p_inputs[i]) + batch_index * num_heads * input_sizes[i];
    outputs[i] = reinterpret_cast<T*>(p_outputs[i]) + batch_index * num_heads * input_sizes[i];
  }

  // Local attention token
  int col_start = 0;
  int col_end = sequence_length;
  bool is_local_row = (global_attention[batch_index * sequence_length + row_index] == static_cast<int>(0));
  if (is_local_row) {
    col_start = row_index - window;
    if (col_start < 0) {
      col_start = 0;
    }
    col_end = row_index + window + 1;
    if (col_end > sequence_length) {
      col_end = sequence_length;
    }
  }

  // If mask is set then set everything to zero to match huggingface transformers implementation
  if ((float)mask_block[row_index] != 0.f) {
    if (is_local_row) {
      T* output_block = nullptr;
      T* output_global = nullptr;
      int local_offset = row_index % window;
      int local_start = 0;
      int local_end = 3 * window;
      if (row_index < window) {
        local_start = 0;
        local_end = 2 * window;
        output_block = outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
      } else if (row_index < sequence_length - window) {
        output_block = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
      } else {
        local_start = 0;
        local_end = 2 * window;
        output_block = outputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
      }

      for (int i = local_start + tid; i < local_end; i += blockSize) {
        output_block[i] = 0;
      }

      if ((row_index - 2 * window) >= 0) {
        output_global = outputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
      }

      if (output_global != nullptr) {
        for (int i = tid; i < global_num; i += blockSize) {
          output_global[i] = 0;
        }
      }
    } else {
      T* output_block = outputs[4];
      for (int i = tid; i < sequence_length; i += blockSize)
        output_block[i] = 0;
    }
    return;
  }

  float sum_input = 0.;
  __shared__ float sum_shared;

  // Calculate max input
  float max_input = -std::numeric_limits<float>::infinity();
  __shared__ float max_shared;

  if (is_local_row) {
    const T* input_block = nullptr;
    T* output_block = nullptr;
    T* output_global = nullptr;
    int local_offset = row_index % window;
    int local_start = local_offset;
    int local_end = local_start + 2 * window + 1;
    int zero_start = 0;
    int zero_end = 3 * window;
    if (row_index < window) {
      local_start = 0;
      local_end = local_offset + window + 1;
      zero_end = 2 * window;
      input_block = inputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
      output_block = outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
    } else if (row_index < sequence_length - window) {
      input_block = inputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
      output_block = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
    } else {
      local_start = local_offset;
      local_end = 2 * window;
      zero_end = 2 * window;
      input_block = inputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
      output_block = outputs[2] + local_offset * input_strides[2] + head_index * input_sizes[2];
    }

    const T* input_global = nullptr;
    int local_global = row_index - window;
    if (local_global > global_num) {
      local_global = global_num;
    }
    if (local_global > 0) {
      input_global = inputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
    }

    if (row_index < window) {
      output_global = (T*)outputs[0] + row_index * input_strides[0] + head_index * input_sizes[0];
    } else if (row_index < 2 * window) {
      output_global = outputs[1] + (row_index - window) * input_strides[1] + head_index * input_sizes[1];
    } else {
      output_global = outputs[3] + (row_index - window) * input_strides[3] + head_index * input_sizes[3];
    }

    for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
      float x = input_block[i];
      x = x * scaler + (float)mask_block[j];
      if (max_input < x)
        max_input = x;
    }

    if (input_global != nullptr) {
      for (int i = tid; i < local_global; i += blockSize) {
        float x = input_global[global_index_block[i]];
        x = x * scaler + (float)mask_block[global_index_block[i]];
        if (max_input < x)
          max_input = x;
      }
    }

    float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
    if (tid == 0) {
      max_shared = max_block;
    }
    __syncthreads();

    for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
      float x = input_block[i];
      x = expf((x)*scaler + (float)mask_block[j] - max_shared);
      sum_input += x;
    }

    if (input_global != nullptr) {
      for (int i = tid, j = col_start + tid; i < local_global; i += blockSize, j += blockSize) {
        float x = input_global[global_index_block[i]];
        x = expf((x)*scaler + (float)mask_block[j] - max_shared);
        sum_input += x;
      }
    }

    float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
    if (tid == 0) {
      sum_shared = sum_block;
    }
    __syncthreads();

    float recip_sum = 1.f / sum_shared;

    for (int i = tid + zero_start; i < local_start; i += blockSize) {
      output_block[i] = (T)(0.);
    }

    for (int i = tid + local_end; i < zero_end; i += blockSize) {
      output_block[i] = (T)(0.);
    }

    __syncthreads();

    for (int i = local_start + tid, j = col_start + tid; i < local_end; i += blockSize, j += blockSize) {
      float x = input_block[i];
      x = expf((x)*scaler + (float)mask_block[j] - max_shared);
      output_block[i] = (T)(recip_sum * x);
    }

    if (input_global != nullptr) {
      for (int i = tid; i < local_global; i += blockSize) {
        float x = input_global[global_index_block[i]];
        x = expf((x)*scaler + (float)mask_block[global_index_block[i]] - max_shared);
        output_global[i] = (T)(recip_sum * x);
      }
    }
  } else {
    // Global tokens
    const T* input_block = inputs[4] + row_index * input_strides[4] + head_index * input_sizes[4];
    T* output_block = outputs[4] + row_index * input_strides[4] + head_index * input_sizes[4];

    for (int i = tid; i < sequence_length; i += blockSize) {
      float x = input_block[i];
      x = x * scaler + (float)mask_block[i];
      if (max_input < x)
        max_input = x;
    }

    float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
    if (tid == 0) {
      max_shared = max_block;
    }
    __syncthreads();

    for (int i = tid; i < sequence_length; i += blockSize) {
      float x = input_block[i];
      x = expf((x)*scaler + (float)mask_block[i] - max_shared);
      sum_input += x;
    }

    float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
    if (tid == 0) {
      sum_shared = sum_block;
    }
    __syncthreads();

    float recip_sum = 1.f / sum_shared;
    for (int i = tid; i < sequence_length; i += blockSize) {
      float x = input_block[i];
      x = expf((x)*scaler + (float)mask_block[i] - max_shared);
      output_block[i] = (T)(recip_sum * x);
    }
  }
}
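To make the compact scratch1 layout read by this kernel concrete, the following standalone sketch (the example dimensions are assumptions) computes the five per-head block sizes that the launcher below builds in buffer_sizes, and checks that they sum to the (5S - 3W)*W elements per head counted in GetScratch1Size:

// Illustrative only: per-head element counts of the five compact blocks for S=1024, W=256.
#include <cassert>
#include <cstddef>

void CompactLayoutExample() {
  const size_t S = 1024, W = 256;
  const size_t middle_count = (S - 2 * W) / W;  // rows of blocks between the first and last row
  const size_t sizes[5] = {
      W * W * 2,                 // [0] first row: 2 WxW blocks
      W * W * middle_count * 3,  // [1] middle rows: 3 WxW blocks each
      W * W * 2,                 // [2] last row: 2 WxW blocks
      W * (S - W),               // [3] local tokens attending global tokens
      W * S};                    // [4] global tokens attending everything
  size_t total = 0;
  for (size_t s : sizes) total += s;
  assert(total == (5 * S - 3 * W) * W);  // matches the per-head count used in GetScratch1Size
}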
Status LaunchLongformerSoftmaxKernel(
    hipStream_t stream,
    rocblas_handle rocblas,
    void* workspace,
    const void* q,                // transposed Q with shape (B, N, S, H)
    const void* k,                // transposed K with shape (B, N, S, H)
    const void* v,                // transposed V with shape (B, N, S, H)
    const void* attention_mask,   // attention mask with shape (B, S), with value 0 not masked and -10000 masked.
    int max_num_global,           // maximum number of global tokens (G)
    const bool compact_global_q,  // whether global_q has shape (B, N, G, H) instead of (B, N, S, H)
    const void* global_q,         // Q for global tokens with shape (B, N, S, H).
    const void* global_k,         // K for global tokens with shape (B, N, S, H)
    const void* global_v,         // V for global tokens with shape (B, N, S, H)
    const int* global_attention,  // global attention flags with shape (B, S), with value 0 for local and 1 for global.
    const int* global_index,      // Global index with shape (B, S)
    const int* batch_global_num,  // Number of global tokens per batch with shape (B, 1)
    void* pinned_buffer,          // Pinned memory in CPU with 2 parts: global tokens per batch, and data for scratch2
    void* output,                 // output with shape (B, N, S, H)
    float scaler,                 // scalar
    int batch_size,               // batch size
    int sequence_length,          // sequence length
    int num_heads,                // number of heads
    int head_size,                // hidden size per head
    int window,                   // one sided window size
    size_t element_size) {        // size of element: 2 for half, and 4 for float
  const int* global_count = reinterpret_cast<const int*>(pinned_buffer);

  bool is_fp16 = (element_size == 2);
  char* scratch1 = reinterpret_cast<char*>(workspace);
  char* scratch2 = scratch1 + GetScratch1Size(element_size, batch_size, num_heads, sequence_length, window);

  // Setup shared parameters for two strided batched matrix multiplies
  rocblas_datatype Atype;
  rocblas_datatype Btype;
  rocblas_datatype Ctype;
  rocblas_datatype resultType;
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  __half one_fp16, zero_fp16;
  float one_fp32, zero_fp32;
  void *alpha, *beta_0, *beta_1;

  if (is_fp16) {
    one_fp16 = __float2half(1.f);
    zero_fp16 = __float2half(0.f);
    alpha = static_cast<void*>(&one_fp16);
    beta_0 = static_cast<void*>(&zero_fp16);
    beta_1 = static_cast<void*>(&one_fp16);
    Atype = rocblas_datatype_f16_r;
    Btype = rocblas_datatype_f16_r;
    Ctype = rocblas_datatype_f16_r;
    resultType = rocblas_datatype_f16_r;
    algo = rocblas_gemm_algo_standard;
  } else {
    one_fp32 = 1.f;
    zero_fp32 = 0.f;
    alpha = static_cast<void*>(&one_fp32);
    beta_0 = static_cast<void*>(&zero_fp32);
    beta_1 = static_cast<void*>(&one_fp32);
    Atype = rocblas_datatype_f32_r;
    Btype = rocblas_datatype_f32_r;
    Ctype = rocblas_datatype_f32_r;
    resultType = rocblas_datatype_f32_r;
  }

  // Strided batch matrix multiply
  //    qk = q * k^T
  // Shapes: q and k = B x N x S x H, qk = B x N x S x S
  // Convert col-major to row-major by swapping q and k in Gemm
  size_t elements_per_batch = num_heads * sequence_length * head_size;
  int stride_per_head = sequence_length * head_size;  // stride for Q, K, V and output

  // Local attention part
  // S x S is calculated using sliding block WxW (W is one sided window size) like the following:
  //   [W][W]
  //   [W][W][W]
  //      [W][W][W]
  //         [W][W]
  // The first and last rows have 2 blocks per row, and the remaining rows have 3 blocks per row.
  // The calculation is split into 3 parts: the first row, middle rows and finally the last row.
  // To save space, we do not store the whole matrix. Instead, we only allocate space for these blocks.
  //
  // For global attention part, we have two assumptions:
  // (1) Global tokens are at the beginning of sequence
  // (2) Number of global tokens <= attention window
  //
  // The results are stored in scratch1 buffer:
  //   Number of elements for local attention are (3*S/W-2)*W*W*N*B, or (3S-2W)*W*N*B
  //   Number of elements for local attends to global are (S-W)*W*N*B
  //   Number of elements for global attends to everything are S*W*N*B
  // Total elements (FP16 or FP32) are (5S-3W)*W*N*B
  const int w = window;
  const int middle_count = (sequence_length - 2 * w) / w;
  int last_block = (sequence_length / w) - 1;

  // Determine the non-zero block dimensions and pointers

  // Buffer size per head for a single batch
  size_t buffer_sizes[5] = {
      static_cast<size_t>(w * w * 2),                  // first row of blocks has 2 WxW blocks
      static_cast<size_t>(w * w * middle_count * 3),   // middle rows of blocks have 3 WxW blocks per row
      static_cast<size_t>(w * w * 2),                  // last row of blocks has 2 WxW blocks
      static_cast<size_t>(w * (sequence_length - w)),  // local attends to global: global tokens <= window size
      static_cast<size_t>(w * sequence_length)};       // global attends to everything.

  size_t buffer_strides[5] = {
      static_cast<size_t>(w * 2),
      static_cast<size_t>(w * 3),
      static_cast<size_t>(w * 2),
      static_cast<size_t>(w),  // number of global tokens <= window size
      static_cast<size_t>(sequence_length)};

  void* buffer_pointers[5];

  char* current_pointer = scratch1;
  for (int i = 0; i < 5; ++i) {
    buffer_pointers[i] = reinterpret_cast<void*>(current_pointer);
    current_pointer += buffer_sizes[i] * num_heads * batch_size * element_size;
  }

  // Copy to a contiguous buffer first so that we only need to call hipMemcpyAsync once
  char* temp_buffer = reinterpret_cast<char*>(pinned_buffer) + sizeof(int) * batch_size;
  memcpy(temp_buffer, &buffer_pointers[0], 5 * sizeof(void*));
  memcpy(temp_buffer + 5 * sizeof(void*), &buffer_sizes[0], 5 * sizeof(size_t));
  memcpy(temp_buffer + 5 * sizeof(void*) + 5 * sizeof(size_t), &buffer_strides[0], 5 * sizeof(size_t));
  CHECK_ROCM(hipMemcpyAsync(scratch2, temp_buffer, GetScratch2Size(), hipMemcpyHostToDevice, stream));

  // Local attention part
  {
    // local attention per head - head
    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_transpose, rocblas_operation_none,
        2 * w, w, head_size,                                // m, n, k
        alpha,                                              // alpha
        k, Atype, head_size, stride_per_head,               // A, A type, lda, strideA
        q, Btype, head_size, stride_per_head,               // B, B type, ldb, strideB
        beta_0,                                             // beta
        buffer_pointers[0], Ctype, 2 * w, buffer_sizes[0],  // C, C type, ldc, strideC
        batch_size * num_heads,                             // batch count
        resultType, algo));

    // local attention per head - middle
    if (middle_count > 0) {
      for (int i = 0; i < batch_size; ++i) {
        for (int j = 0; j < num_heads; ++j) {
          const void* q_head = reinterpret_cast<const char*>(q) +
                               (i * elements_per_batch + (j * sequence_length + w) * head_size) * element_size;
          const void* k_head = reinterpret_cast<const char*>(k) +
                               (i * elements_per_batch + j * sequence_length * head_size) * element_size;
          void* qk_head = reinterpret_cast<char*>(buffer_pointers[1]) +
                          static_cast<size_t>(i * num_heads + j) * buffer_sizes[1] * element_size;
          CHECK(_compat_rocblas_gemm_strided_batched_ex(
              rocblas, rocblas_operation_transpose, rocblas_operation_none,
              3 * w, w, head_size,                      // m, n, k
              alpha,                                    // alpha
              k_head, Atype, head_size, w * head_size,  // A, A type, lda, strideA
              q_head, Btype, head_size, w * head_size,  // B, B type, ldb, strideB
              beta_0,                                   // beta
              qk_head, Ctype, 3 * w, 3 * w * w,         // C, C type, ldc, strideC
              middle_count,                             // batch count
              resultType, algo));
        }
      }
    }

    // local attention per head - tail
    const void* q_head = reinterpret_cast<const char*>(q) + (last_block * w * head_size) * element_size;
    const void* k_head = reinterpret_cast<const char*>(k) + ((last_block - 1) * w * head_size) * element_size;

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_transpose, rocblas_operation_none,
        2 * w, w, head_size,                                // m, n, k
        alpha,                                              // alpha
        k_head, Atype, head_size, stride_per_head,          // A, A type, lda, strideA
        q_head, Btype, head_size, stride_per_head,          // B, B type, ldb, strideB
        beta_0,                                             // beta
        buffer_pointers[2], Ctype, 2 * w, buffer_sizes[2],  // C, C type, ldc, strideC
        batch_size * num_heads,                             // batch count
        resultType, algo));
  }

  // Global attention part
  for (int i = 0; i < batch_size; ++i) {
    if (global_count[i] > 0) {
      const void* q_batch = reinterpret_cast<const char*>(q) +
                            (i * elements_per_batch + w * head_size) * element_size;
      const void* k_batch = reinterpret_cast<const char*>(k) + (i * elements_per_batch) * element_size;
      void* qk_batch = reinterpret_cast<char*>(buffer_pointers[3]) +
                       (i * buffer_sizes[3]) * num_heads * element_size;

      // Local tokens attending global tokens
      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_transpose, rocblas_operation_none,
          global_count[i], sequence_length - w, head_size,  // m, n, k
          alpha,                                            // alpha
          k_batch, Atype, head_size, stride_per_head,       // A, A type, lda, strideA
          q_batch, Btype, head_size, stride_per_head,       // B, B type, ldb, strideB
          beta_0,                                           // beta
          qk_batch, Ctype, w, buffer_sizes[3],              // C, C type, ldc, strideC
          num_heads,                                        // batch count
          resultType, algo));

      const size_t global_q_per_batch = compact_global_q ? num_heads * max_num_global * head_size : elements_per_batch;
      const int global_q_stride = (compact_global_q ? max_num_global * head_size : stride_per_head);
      const void* global_q_batch = reinterpret_cast<const char*>(global_q) + (i * global_q_per_batch) * element_size;
      const void* global_k_batch = reinterpret_cast<const char*>(global_k) + (i * elements_per_batch) * element_size;
      qk_batch = reinterpret_cast<char*>(buffer_pointers[4]) + (i * buffer_sizes[4] * num_heads) * element_size;

      // Global tokens attending everything
      // These GEMMs need to be last to make sure all global token entries are re-written.
      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_transpose, rocblas_operation_none,
          sequence_length, global_count[i], head_size,        // m, n, k
          alpha,                                              // alpha
          global_k_batch, Atype, head_size, stride_per_head,  // A, A type, lda, strideA
          global_q_batch, Btype, head_size, global_q_stride,  // B, B type, ldb, strideB
          beta_0,                                             // beta
          qk_batch, Ctype, sequence_length, buffer_sizes[4],  // C, C type, ldc, strideC
          num_heads,                                          // batch count
          resultType, algo));
    }
  }

  const int blockSize = 64;
  const int gridSize = batch_size * num_heads * sequence_length;
  if (is_fp16) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxKernel<__half, blockSize>),
                       gridSize, blockSize, 0, stream,
                       global_attention, global_index, batch_global_num, scratch2,
                       static_cast<const __half*>(attention_mask),
                       scaler, sequence_length, num_heads, window);
  } else {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxKernel<float, blockSize>),
                       gridSize, blockSize, 0, stream,
                       global_attention, global_index, batch_global_num, scratch2,
                       static_cast<const float*>(attention_mask),
                       scaler, sequence_length, num_heads, window);
  }

  // local values attending the softmax score.
  {
    // local attention per head - head
    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_none, rocblas_operation_none,
        head_size, w, 2 * w,                                                              // m, n, k
        alpha,                                                                            // alpha
        v, Atype, head_size, stride_per_head,                                             // A, A type, lda, strideA
        buffer_pointers[0], Btype, static_cast<int>(buffer_strides[0]), buffer_sizes[0],  // B, B type, ldb, strideB
        beta_0,                                                                           // beta
        output, Ctype, head_size, stride_per_head,                                        // C, C type, ldc, strideC
        batch_size * num_heads,                                                           // batch count
        resultType, algo));

    // local attention per head - middle
    if (middle_count > 0) {
      for (int i = 0; i < batch_size; ++i) {
        for (int j = 0; j < num_heads; ++j) {
          const void* v_head = reinterpret_cast<const char*>(v) +
                               (i * elements_per_batch + j * head_size * sequence_length) * element_size;
          const void* prob_head = reinterpret_cast<const char*>(buffer_pointers[1]) +
                                  (i * num_heads + j) * buffer_sizes[1] * element_size;
          void* out_head = reinterpret_cast<char*>(output) +
                           (i * elements_per_batch + j * head_size * sequence_length + w * head_size) * element_size;
          CHECK(_compat_rocblas_gemm_strided_batched_ex(
              rocblas, rocblas_operation_none, rocblas_operation_none,
              head_size, w, 3 * w,                                                     // m, n, k
              alpha,                                                                   // alpha
              v_head, Atype, head_size, w * head_size,                                 // A, A type, lda, strideA
              prob_head, Btype, static_cast<int>(buffer_strides[1]), 3 * w * w,        // B, B type, ldb, strideB
              beta_0,                                                                  // beta
              out_head, Ctype, head_size, w * head_size,                               // C, C type, ldc, strideC
              middle_count,                                                            // batch count
              resultType, algo));
        }
      }
    }

    // local attention per head - tail
    const void* v_head = reinterpret_cast<const char*>(v) + (last_block - 1) * w * head_size * element_size;
    void* out_head = reinterpret_cast<char*>(output) + last_block * w * head_size * element_size;

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_none, rocblas_operation_none,
        head_size, w, 2 * w,                                                              // m, n, k
        alpha,                                                                            // alpha
        v_head, Atype, head_size, stride_per_head,                                        // A, A type, lda, strideA
        buffer_pointers[2], Btype, static_cast<int>(buffer_strides[2]), buffer_sizes[2],  // B, B type, ldb, strideB
        beta_0,                                                                           // beta
        out_head, Ctype, head_size, stride_per_head,                                      // C, C type, ldc, strideC
        batch_size * num_heads,                                                           // batch count
        resultType, algo));
  }

  // global attention part
  for (int i = 0; i < batch_size; ++i) {
    if (global_count[i] > 0) {
      // Local tokens attending global tokens
      const void* v_head = reinterpret_cast<const char*>(v) + (i * elements_per_batch) * element_size;
      const void* prob_head = reinterpret_cast<const char*>(buffer_pointers[3]) +
                              (i * buffer_sizes[3] * num_heads + w * buffer_strides[3]) * element_size;
      void* out_head = reinterpret_cast<char*>(output) +
                       (i * elements_per_batch + 2 * w * head_size) * element_size;

      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          head_size, sequence_length - 2 * w, global_count[i],                       // m, n, k
          alpha,                                                                     // alpha
          v_head, Atype, head_size, stride_per_head,                                 // A, A type, lda, strideA
          prob_head, Btype, static_cast<int>(buffer_strides[3]), buffer_sizes[3],    // B, B type, ldb, strideB
          beta_1,                                                                    // beta
          out_head, Ctype, head_size, stride_per_head,                               // C, C type, ldc, strideC
          num_heads,                                                                 // batch count
          resultType, algo));

      // Global tokens attending everything
      v_head = reinterpret_cast<const char*>(global_v) + (i * elements_per_batch) * element_size;
      prob_head = reinterpret_cast<const char*>(buffer_pointers[4]) +
                  (i * buffer_sizes[4] * num_heads) * element_size;
      out_head = reinterpret_cast<char*>(output) + (i * elements_per_batch) * element_size;

      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          head_size, global_count[i],                                                // m, n
          sequence_length,                                                           // k: re-write entries completely
          alpha,                                                                     // alpha
          v_head, Atype, head_size, stride_per_head,                                 // A, A type, lda, strideA
          prob_head, Btype, static_cast<int>(buffer_strides[4]), buffer_sizes[4],    // B, B type, ldb, strideB
          beta_0,                                                                    // beta: overwrite
          out_head,                                                                  // C: assumes global tokens at the beginning of sequence
          Ctype, head_size, stride_per_head,                                         // C type, ldc, strideC
          num_heads,                                                                 // batch count
          resultType, algo));
    }
  }
  return Status::OK();
}
template <typename T>
Status LongformerQkvToContext(
    const hipDeviceProp_t& device_prop, rocblas_handle rocblas, hipStream_t stream,
    const int batch_size,         // batch size
    const int sequence_length,    // sequence length
    const int num_heads,          // number of attention heads
    const int head_size,          // hidden size per head
    const int window,             // Half (one-sided) window size
    const size_t element_size,
    const T* input,               // input for transpose
    const T* bias,                // bias to add to transposed input
    const T* attention_mask,      // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
    const T* global_input,        // global input for transpose
    const T* global_bias,         // bias to add to transposed global input
    const int* global_attention,  // global attention flags with shape (B, S), with value 0 for local and 1 for global.
    const int* global_index,      // Global index with shape (B, S)
    const int* batch_global_num,  // Number of global tokens per batch with shape (B, 1)
    const int max_num_global,     // Maximum number of global tokens (G)
    void* pinned_buffer,          // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
    T* workspace,                 // Softmax space
    T* output,                    // output
    size_t softmax_workspace_size,
    bool disable_compact_memory,
    bool use_merged_qkv_weights,
    bool use_half4) {
  T* qkv = reinterpret_cast<T*>(reinterpret_cast<char*>(workspace) + softmax_workspace_size);

  // Number of elements in Q, K, V, Global_Q, Global_K or Global_V are the same: BxNxSxH
  const int elements = batch_size * num_heads * sequence_length * head_size;

  const int max_threads_per_block(device_prop.maxThreadsPerBlock);

  const int format = static_cast<int>(use_merged_qkv_weights);
  bool compact_global_q = false;
  // The order of qkv space:
  //   Q, K, V, Global_K, Global_V, Global_Q (format 0)
  //   Q, K, V, Global_Q, Global_K, Global_V (format 1)
  // Assume Q, K and V have the same hidden size
  if (format == 1 || max_num_global == 0 || nullptr == global_input) {
    if (bias == nullptr) {
      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
                                         max_threads_per_block, false, input, qkv));
    } else {
      LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
                             batch_size, sequence_length, num_heads, head_size,
                             input, bias, qkv, use_half4, head_size);
    }

    if (max_num_global > 0 && nullptr != global_input) {
      if (global_bias == nullptr) {
        ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
                                           max_threads_per_block, false, global_input, qkv + 3 * elements));
      } else {
        LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
                               batch_size, sequence_length, num_heads, head_size,
                               global_input, global_bias, qkv + 3 * elements, use_half4, head_size);
      }
    }
  } else {
    LaunchAddBiasTranspose(stream, 5, format, max_threads_per_block,
                           batch_size, sequence_length, num_heads, head_size,
                           input, bias, qkv, use_half4, head_size);

    compact_global_q = (disable_compact_memory == false);
    LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
                           batch_size, compact_global_q ? max_num_global : sequence_length, num_heads, head_size,
                           global_input + 2 * elements, global_bias, qkv + 5 * elements, use_half4, head_size);
  }
  HIP_RETURN_IF_ERROR(hipGetLastError());

  // Transposed Q, K, V with shape (B, N, S, H)
  const T* q = qkv;
  const T* k = q + elements;
  const T* v = k + elements;

  // Transposed global Q, K, V with shape (B, N, S, H).
  // When compact_global_q is true, Global Q has actual shape (B, N, G, H) although we allocated space of (B, N, S, H)
  // When max_num_global == 0, these pointers are not used in GEMM so the value does not matter.
  const T* global_q = (format == 1 ? v + elements : qkv + 5 * elements);
  const T* global_k = (format == 1 ? global_q + elements : qkv + 3 * elements);
  const T* global_v = (format == 1 ? global_k + elements : qkv + 4 * elements);

  // Q*K' are scaled by 1/sqrt(H)
  const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));

  T* temp_output = qkv;  // Q will be overwritten
  if (disable_compact_memory) {
    ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxSimpleKernel(
        stream, rocblas, workspace,
        q, k, v, attention_mask,
        global_q, global_k, global_v,
        global_attention, global_index, batch_global_num,
        pinned_buffer, temp_output, rsqrt_head_size,
        batch_size, sequence_length, num_heads, head_size,
        window, element_size));
  } else {
    ORT_ENFORCE(max_num_global <= window);
    ORT_RETURN_IF_ERROR(LaunchLongformerSoftmaxKernel(
        stream, rocblas, workspace,
        q, k, v, attention_mask,
        max_num_global, compact_global_q,
        global_q, global_k, global_v,
        global_attention, global_index, batch_global_num,
        pinned_buffer, temp_output, rsqrt_head_size,
        batch_size, sequence_length, num_heads, head_size,
        window, element_size));
  }

  // The temp_output is BxNxSxH, transpose it to final output BxSxNxH
  return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
                        max_threads_per_block, false, temp_output, output);
}
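The global Q/K/V pointer arithmetic above depends on the ordering of the six BxNxSxH chunks inside the qkv buffer. A small standalone sketch (illustrative only; it simply restates the offsets computed in LongformerQkvToContext in units of `elements` = B*N*S*H):

// Illustrative only: chunk indices inside the qkv buffer for the two layouts.
#include <cstdio>

void PrintQkvLayout(int format) {
  // format 1 (merged QKV weights): Q, K, V, Global_Q, Global_K, Global_V
  // format 0:                      Q, K, V, Global_K, Global_V, Global_Q
  const int q = 0, k = 1, v = 2;
  const int global_q = (format == 1) ? 3 : 5;
  const int global_k = (format == 1) ? 4 : 3;
  const int global_v = (format == 1) ? 5 : 4;
  std::printf("q=%d k=%d v=%d global_q=%d global_k=%d global_v=%d\n",
              q, k, v, global_q, global_k, global_v);
}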
Status LaunchLongformerAttentionKernel(
    const hipDeviceProp_t& device_prop,
    rocblas_handle rocblas,
    hipStream_t stream,
    const void* input,
    const void* bias,
    const void* attention_mask,
    const void* global_input,
    const void* global_bias,
    const int* global_attention,
    const int* global_index,
    const int* batch_global_num,
    void* pinned_buffer,
    void* workspace,
    void* output,
    int batch_size,
    int sequence_length,
    int num_heads,
    int head_size,
    int window,
    int max_num_global,
    const size_t element_size,
    bool disable_compact_memory,
    bool use_merged_qkv_weights,
    bool use_half4) {
  CompatRocblasMathModeSetter helper(device_prop, rocblas, 0 /* CUBLAS_TENSOR_OP_MATH is deprecated */);

  size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size,
                                                                    batch_size,
                                                                    num_heads,
                                                                    sequence_length,
                                                                    window,
                                                                    disable_compact_memory);
  if (element_size == 2) {
    return LongformerQkvToContext(device_prop, rocblas, stream,
                                  batch_size, sequence_length, num_heads, head_size, window, element_size,
                                  reinterpret_cast<const half*>(input),
                                  reinterpret_cast<const half*>(bias),
                                  reinterpret_cast<const half*>(attention_mask),
                                  reinterpret_cast<const half*>(global_input),
                                  reinterpret_cast<const half*>(global_bias),
                                  global_attention, global_index, batch_global_num,
                                  max_num_global, pinned_buffer,
                                  reinterpret_cast<half*>(workspace),
                                  reinterpret_cast<half*>(output),
                                  softmax_workspace_size,
                                  disable_compact_memory,
                                  use_merged_qkv_weights,
                                  use_half4);
  } else {
    return LongformerQkvToContext(device_prop, rocblas, stream,
                                  batch_size, sequence_length, num_heads, head_size, window, element_size,
                                  reinterpret_cast<const float*>(input),
                                  reinterpret_cast<const float*>(bias),
                                  reinterpret_cast<const float*>(attention_mask),
                                  reinterpret_cast<const float*>(global_input),
                                  reinterpret_cast<const float*>(global_bias),
                                  global_attention, global_index, batch_global_num,
                                  max_num_global, pinned_buffer,
                                  reinterpret_cast<float*>(workspace),
                                  reinterpret_cast<float*>(output),
                                  softmax_workspace_size,
                                  disable_compact_memory,
                                  use_merged_qkv_weights,
                                  false);
  }
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_impl.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/rocm/shared_inc/rocm_utils.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {

size_t GetPinnedBufferSize(size_t batch_size);

size_t GetLongformerAttentionWorkspaceSize(
    size_t element_size,
    size_t batch_size,
    size_t num_heads,
    size_t head_size,
    size_t sequence_length,
    size_t max_num_global,
    size_t window,
    bool disable_compact_memory);

Status LaunchLongformerAttentionKernel(
    const hipDeviceProp_t& device_prop,  // Device Properties
    rocblas_handle rocblas,              // Rocblas handle
    hipStream_t stream,                  // ROCM stream
    const void* input,                   // Input tensor
    const void* bias,                    // Bias tensor
    const void* attention_mask,          // Attention mask with shape (B, S)
    const void* global_input,            // Global attention input, or nullptr when max_num_global == 0.
    const void* global_bias,             // Global bias tensor
    const int* global_attention,         // Global attention flags with shape (B, S)
    const int* global_index,             // Global index
    const int* batch_global_num,         // Number of global tokens per batch. It is in device memory.
    void* pinned_buffer,                 // Pinned memory: copy of batch_global_num, and a buffer to copy to scratch2.
    void* workspace,                     // Temporary buffer
    void* output,                        // Output tensor
    int batch_size,                      // Batch size (B)
    int sequence_length,                 // Sequence length (S)
    int num_heads,                       // Number of attention heads (N)
    int head_size,                       // Hidden layer size per head (H)
    int window,                          // One sided attention window (W)
    int max_num_global,                  // Maximum number of global tokens (G)
    const size_t element_size,           // Element size of input tensor,
    bool disable_compact_memory,         // Disable compact memory kernel
    bool use_merged_qkv_weights,
    bool use_half4);

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
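As a usage sketch for the API declared above (the dimensions, stream/handle setup, and allocation are assumptions for illustration; input tensor preparation and error handling are omitted, and the caller is assumed to have staged a host copy of batch_global_num into the pinned buffer beforehand):

// Illustrative only: rough host-side call sequence for the Longformer attention kernel.
Status RunLongformerAttentionExample(const void* input, const void* bias,
                                     const void* attention_mask,
                                     const void* global_input, const void* global_bias,
                                     const int* global_attention, const int* global_index,
                                     const int* batch_global_num, void* output) {
  const int batch_size = 1, sequence_length = 4096, num_heads = 12, head_size = 64;
  const int window = 512, max_num_global = 64;
  const size_t element_size = 2;  // fp16
  const bool disable_compact_memory = false;

  hipDeviceProp_t device_prop;
  hipGetDeviceProperties(&device_prop, 0);
  hipStream_t stream;
  hipStreamCreate(&stream);
  rocblas_handle rocblas;
  rocblas_create_handle(&rocblas);

  void* workspace = nullptr;
  void* pinned_buffer = nullptr;
  hipMalloc(&workspace,
            GetLongformerAttentionWorkspaceSize(element_size, batch_size, num_heads, head_size,
                                                sequence_length, max_num_global, window,
                                                disable_compact_memory));
  hipHostMalloc(&pinned_buffer, GetPinnedBufferSize(batch_size), 0);
  // Assumed precondition: pinned_buffer already holds a host copy of batch_global_num.

  Status status = LaunchLongformerAttentionKernel(
      device_prop, rocblas, stream,
      input, bias, attention_mask, global_input, global_bias,
      global_attention, global_index, batch_global_num,
      pinned_buffer, workspace, output,
      batch_size, sequence_length, num_heads, head_size,
      window, max_num_global, element_size,
      disable_compact_memory, /*use_merged_qkv_weights=*/true, /*use_half4=*/true);

  hipStreamSynchronize(stream);
  hipFree(workspace);
  hipHostFree(pinned_buffer);
  rocblas_destroy_handle(rocblas);
  hipStreamDestroy(stream);
  return status;
}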
build/Linux/Release/amdgpu/onnxruntime/contrib_ops/rocm/bert/longformer_attention_softmax.cu
#include "hip/hip_runtime.h"
/*
Copyright (c) NVIDIA Corporation and Microsoft Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// These are rocm kernels for longformer attention softmax that do not use compact memory.
// They use two temporary matrices of BxNxSxS, and consume more memory when sequence length is large.
// Their logic is simpler with fewer constraints (like the number of global tokens could be larger than the attention window).
#include <hipcub/hipcub.hpp>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <limits>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/longformer_attention_softmax.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
using namespace onnxruntime::rocm;
using namespace hipcub;

#define CHECK(expr) ROCBLAS_RETURN_IF_ERROR(expr)

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T, int blockSize>
__launch_bounds__(blockSize)
__global__ void LongformerSoftmaxSimpleKernel(const int* global_attention,
                                              const int* global_index,
                                              const int* batch_global_num,
                                              const T* input,
                                              const T* attention_mask,
                                              T* output,
                                              float scaler,
                                              int dim0,
                                              int sequence_length,
                                              int attention_window) {
  typedef hipcub::BlockReduce<float, blockSize> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_reduce_temp;
  __shared__ float max_shared;
  __shared__ float sum_shared;

  const T* input_block = input + sequence_length * blockIdx.x;
  T* output_block = output + sequence_length * blockIdx.x;
  const int batch_index = blockIdx.x / dim0;
  const int row_index = blockIdx.x % sequence_length;
  const int global_num = batch_global_num[batch_index];

  // To be consistent with Huggingface Longformer, the rows of masked words are set to zero.
  if ((float)attention_mask[batch_index * sequence_length + row_index] < 0.0f) {
    for (int i = threadIdx.x; i < sequence_length; i += blockSize) {
      output_block[i] = (T)(0);
    }
    return;
  }

  // local attention token
  int col_start = 0;
  int col_end = sequence_length;
  bool is_local_row = (global_attention[batch_index * sequence_length + row_index] == (int)0);
  if (is_local_row) {
    col_start = row_index - attention_window;
    if (col_start < 0) {
      col_start = 0;
    }

    col_end = row_index + attention_window + 1;
    if (col_end > sequence_length) {
      col_end = sequence_length;
    }
  }

  const T* mask_block = attention_mask + sequence_length * batch_index;
  int tid = threadIdx.x;

  // calculate max input
  float max_input = -std::numeric_limits<float>::infinity();
  // #pragma unroll 16
  for (int i = tid + col_start; i < col_end; i += blockSize) {
    float x = input_block[i];
    x = x * scaler + (float)mask_block[i];
    if (max_input < x) {
      max_input = x;
    }
  }

  if (is_local_row) {
    for (int g = tid; g < global_num; g += blockSize) {
      int i = global_index[g];
      if (i < col_start || i >= col_end) {
        float x = input_block[i];
        x = x * scaler + (float)mask_block[i];
        if (max_input < x) {
          max_input = x;
        }
      }
    }
  }

  float max_block = BlockReduce(block_reduce_temp).Reduce(max_input, hipcub::Max());
  if (tid == 0) {
    max_shared = max_block;
  }
  __syncthreads();

  float sum_input = 0.f;
  // #pragma unroll 16
  for (int i = tid + col_start; i < col_end; i += blockSize) {
    float x = input_block[i];
    x = expf((x)*scaler + (float)mask_block[i] - max_shared);
    sum_input += x;
  }

  if (is_local_row) {
    for (int g = tid; g < global_num; g += blockSize) {
      int i = global_index[g];
      if (i < col_start || i >= col_end) {
        float x = input_block[i];
        x = expf((x)*scaler + (float)mask_block[i] - max_shared);
        sum_input += x;
      }
    }
  }

  float sum_block = BlockReduce(block_reduce_temp).Reduce(sum_input, hipcub::Sum());
  if (tid == 0) {
    sum_shared = sum_block;
  }
  __syncthreads();

  float recip_sum = 1.f / sum_shared;

  if (is_local_row) {
    // We only need to fill in zeros for blocks that will be used in the matrix multiplication
    // following the Softmax.
    //
    // For now zero-out only [row_index - 2*attention_window, row_index + 2*attention_window],
    // we can even be more aggressive and reduce the zeroing out window size since
    // each row has entries in 3 blocks (3*attention_window size instead of 4*attention_window)
    int zero_start = row_index - 2 * attention_window;
    if (zero_start < 0) {
      zero_start = 0;
    }

    int zero_end = row_index + 2 * attention_window;
    if (zero_end > sequence_length) {
      zero_end = sequence_length;
    }

    for (int i = tid + zero_start; i < zero_end; i += blockSize) {
      if (i < col_start || i >= col_end) {
        output_block[i] = (T)(0.);
      }
    }
  }
  __syncthreads();

  if (is_local_row) {
    for (int g = tid; g < global_num; g += blockSize) {
      int i = global_index[g];
      if (i < col_start || i >= col_end) {
        float x = input_block[i];
        x = expf((x)*scaler + (float)mask_block[i] - max_shared);
        output_block[i] = (T)(recip_sum * x);
      }
    }
  }

  // #pragma unroll 16
  for (int i = tid + col_start; i < col_end; i += blockSize) {
    float x = input_block[i];
    x = expf((x)*scaler + (float)mask_block[i] - max_shared);
    output_block[i] = (T)(recip_sum * x);
  }
}

// Launch the softmax kernel for non compact memory.
Status LaunchLongformerSoftmaxSimpleKernel(
    hipStream_t stream,
    rocblas_handle rocblas,
    void* workspace,              // softmax space
    const void* q,                // transposed Q with shape (B, N, S, H)
    const void* k,                // transposed K with shape (B, N, S, H)
    const void* v,                // transposed V with shape (B, N, S, H)
    const void* attention_mask,   // attention mask with shape (B, S), with value 0.0 not masked, and -10000.0 masked.
    const void* global_q,         // Q for global tokens with shape (B, N, S, H)
    const void* global_k,         // K for global tokens with shape (B, N, S, H)
    const void* global_v,         // V for global tokens with shape (B, N, S, H)
    const int* global_attention,  // global attention flags with shape (B, S), with value 0 for local and 1 for global.
    const int* global_index,      // Global index with shape (B, S)
    const int* batch_global_num,  // Number of global tokens per batch with shape (B, 1)
    void* pinned_buffer,          // Pinned memory in CPU. Number of global tokens per batch with shape (B, 1)
    void* output,                 // output with shape (B, N, S, H)
    float scaler,                 // scalar
    int batch_size,               // batch size
    int sequence_length,          // sequence length
    int num_heads,                // number of heads
    int head_size,                // hidden size per head
    int attention_window,         // one sided window size
    size_t element_size) {        // size of element: 2 for half, and 4 for float
  bool is_fp16 = (element_size == 2);
  void* scratch1 = reinterpret_cast<char*>(workspace);
  size_t scratch1_size = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, sequence_length);
  void* scratch2 = reinterpret_cast<char*>(scratch1) + scratch1_size;

  // setup shared parameters for two strided batched matrix multiplies
  rocblas_datatype Atype;
  rocblas_datatype Btype;
  rocblas_datatype Ctype;
  rocblas_datatype resultType;
  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;

  __half one_fp16, zero_fp16;
  float one_fp32, zero_fp32;
  void *alpha, *beta_0, *beta_1;

  if (is_fp16) {
    one_fp16 = __float2half(1.f);
    zero_fp16 = __float2half(0.f);
    alpha = static_cast<void*>(&one_fp16);
    beta_0 = static_cast<void*>(&zero_fp16);
    beta_1 = static_cast<void*>(&one_fp16);
    Atype = rocblas_datatype_f16_r;
    Btype = rocblas_datatype_f16_r;
    Ctype = rocblas_datatype_f16_r;
    resultType = rocblas_datatype_f16_r;
    algo = rocblas_gemm_algo_standard;
  } else {
    one_fp32 = 1.f;
    zero_fp32 = 0.f;
    alpha = static_cast<void*>(&one_fp32);
    beta_0 = static_cast<void*>(&zero_fp32);
    beta_1 = static_cast<void*>(&one_fp32);
    Atype = rocblas_datatype_f32_r;
    Btype = rocblas_datatype_f32_r;
    Ctype = rocblas_datatype_f32_r;
    resultType = rocblas_datatype_f32_r;
  }

  // Strided batch matrix multiply
  //    qk = q * k^T
  // Shapes: q and k = B x N x S x H, qk = B x N x S x S
  // Convert col-major to row-major by swapping q and k in Gemm

  // Local attention part
  // S x S is calculated using sliding block WxW (W is one sided window size) like the following:
  //   [W][W]
  //   [W][W][W]
  //      [W][W][W]
  //         [W][W]
  // The first and last rows have 2 blocks, and the remaining rows have 3 blocks per row.
  // The calculation is split into 3 parts: fill the middle rows, then the first row and finally the last row.
  // The results are stored in scratch1.
  int w = attention_window;
  size_t x_offset = static_cast<size_t>(num_heads) * sequence_length * head_size;
  // Use size_t to avoid integer overflow since B x N x S x S is 12G for B=64, N=12, S=4096
  size_t y_offset = static_cast<size_t>(num_heads) * sequence_length * sequence_length;
  int last_block = (sequence_length / w) - 1;
  int strideA = sequence_length * head_size;
  int strideB = sequence_length * head_size;
  int strideC = sequence_length * sequence_length;

  // When S == 2W, there is no middle row of blocks:
  //   [W][W]
  //   [W][W]
  // We can use normal matrix multiplication in this case.
  if (sequence_length == 2 * w) {
    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_transpose, rocblas_operation_none,
        sequence_length, sequence_length, head_size,
        alpha,
        k, Atype, head_size, sequence_length * head_size,
        q, Btype, head_size, sequence_length * head_size,
        beta_0,
        scratch1, Ctype, sequence_length, sequence_length * sequence_length,
        batch_size * num_heads,
        resultType, algo));
  } else {  // sequence_length > 2 * w
    for (int i = 0; i < batch_size; ++i) {
      for (int j = 0; j < num_heads; ++j) {
        const void* q_head = reinterpret_cast<const char*>(q) +
                             (i * x_offset + j * sequence_length * head_size + w * head_size) * element_size;
        const void* k_head = reinterpret_cast<const char*>(k) +
                             (i * x_offset + j * sequence_length * head_size) * element_size;
        void* qk_head = reinterpret_cast<char*>(scratch1) +
                        (i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size;
        int count = (sequence_length - 2 * w) / w;

        CHECK(_compat_rocblas_gemm_strided_batched_ex(
            rocblas, rocblas_operation_transpose, rocblas_operation_none,
            3 * w, w, head_size,                                        // m, n, k
            alpha,                                                      // alpha
            k_head, Atype, head_size, w * head_size,                    // A, A type, lda, strideA
            q_head, Btype, head_size, w * head_size,                    // B, B type, ldb, strideB
            beta_0,                                                     // beta
            qk_head, Ctype, sequence_length, sequence_length * w + w,   // C, C type, ldc, strideC
            count,                                                      // batch count
            resultType, algo));
      }
    }

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_transpose, rocblas_operation_none,
        2 * w, w, head_size,                     // m, n, k
        alpha,                                   // alpha
        k, Atype, head_size, strideA,            // A, A type, lda, strideA
        q, Btype, head_size, strideB,            // B, B type, ldb, strideB
        beta_0,                                  // beta
        scratch1, Ctype, sequence_length, strideC,  // C, C type, ldc, strideC
        batch_size * num_heads,                  // batch count
        resultType, algo));

    const void* q_head = reinterpret_cast<const char*>(q) + (last_block * w * head_size) * element_size;
    const void* k_head = reinterpret_cast<const char*>(k) + ((last_block - 1) * w * head_size) * element_size;
    void* qk_head = reinterpret_cast<char*>(scratch1) +
                    (last_block * w * sequence_length + (last_block - 1) * w) * element_size;

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_transpose, rocblas_operation_none,
        2 * w, w, head_size,
        alpha,
        k_head, Atype, head_size, strideA,
        q_head, Btype, head_size, strideB,
        beta_0,
        qk_head, Ctype, sequence_length, strideC,
        batch_size * num_heads,
        resultType, algo));
  }

  const int* batch_global_count = reinterpret_cast<const int*>(pinned_buffer);
  // Global attention part
  for (int i = 0; i < batch_size; ++i) {
    if (batch_global_count[i] > 0) {
      const void* q_batch = reinterpret_cast<const char*>(q) + (i * x_offset) * element_size;
      const void* k_batch = reinterpret_cast<const char*>(k) + (i * x_offset) * element_size;
      void* qk_batch = reinterpret_cast<char*>(scratch1) + (i * y_offset) * element_size;

      // Local tokens attending global tokens
      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_transpose, rocblas_operation_none,
          batch_global_count[i], sequence_length, head_size,
          alpha,
          k_batch, Atype, head_size, strideA,
          q_batch, Btype, head_size, strideB,
          beta_0,
          qk_batch, Ctype, sequence_length, strideC,
          num_heads,
          resultType, algo));

      const void* global_q_batch = reinterpret_cast<const char*>(global_q) +
                                   (i * num_heads * sequence_length * head_size) * element_size;
      const void* global_k_batch = reinterpret_cast<const char*>(global_k) + (i * x_offset) * element_size;
      int strideB_global = sequence_length * head_size;

      // Global tokens attending everything
      // These GEMMs need to be last to make sure all global token entries are re-written.
      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_transpose, rocblas_operation_none,
          sequence_length, batch_global_count[i], head_size,
          alpha,
          global_k_batch, Atype, head_size, strideA,
          global_q_batch, Btype, head_size, strideB_global,
          beta_0,
          qk_batch, Ctype, sequence_length, strideC,
          num_heads,
          resultType, algo));
    }
  }

  int dim0 = sequence_length * num_heads;
  int dim1 = sequence_length;
  void* softmax_out = scratch2;

  const int blockSize = 64;
  const int gridSize = batch_size * num_heads * sequence_length;
  if (is_fp16) {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxSimpleKernel<__half, blockSize>),
                       gridSize, blockSize, 0, stream,
                       global_attention, global_index, batch_global_num,
                       static_cast<const __half*>(scratch1),
                       static_cast<const __half*>(attention_mask),
                       static_cast<__half*>(softmax_out),
                       scaler, dim0, dim1, attention_window);
  } else {
    hipLaunchKernelGGL(HIP_KERNEL_NAME(LongformerSoftmaxSimpleKernel<float, blockSize>),
                       gridSize, blockSize, 0, stream,
                       global_attention, global_index, batch_global_num,
                       static_cast<const float*>(scratch1),
                       static_cast<const float*>(attention_mask),
                       static_cast<float*>(softmax_out),
                       scaler, dim0, dim1, attention_window);
  }

  // Run the matrix multiply: output = softmax_out * v
  //   softmax_out: B x N x S x S
  //   v: B x N x S x H
  //   attn_out: B x N x S x H
  // Calculation uses full Gemm (S == 2W) or sliding blocks (S > 2W) in a way similar to local attention part.
  if (sequence_length == 2 * w) {
    // convert col-major to row-major by swapping softmax_out and v
    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_none, rocblas_operation_none,
        head_size, sequence_length, sequence_length,
        alpha,
        v, Atype, head_size, sequence_length * head_size,
        softmax_out, Btype, sequence_length, sequence_length * sequence_length,
        beta_0,
        output, Ctype, head_size, sequence_length * head_size,
        batch_size * num_heads,
        resultType, algo));
  } else {  // sequence_length > 2 * w
    for (int i = 0; i < batch_size; ++i) {
      for (int j = 0; j < num_heads; ++j) {
        const void* v_head = reinterpret_cast<const char*>(v) +
                             (i * x_offset + j * head_size * sequence_length) * element_size;
        size_t offset = (i * y_offset + j * sequence_length * sequence_length + w * sequence_length) * element_size;
        const void* prob_head = reinterpret_cast<const char*>(softmax_out) + offset;
        void* out_head = reinterpret_cast<char*>(output) +
                         (i * x_offset + j * head_size * sequence_length + w * head_size) * element_size;
        int count = (sequence_length - 2 * w) / w;

        CHECK(_compat_rocblas_gemm_strided_batched_ex(
            rocblas, rocblas_operation_none, rocblas_operation_none,
            head_size, w, 3 * w,
            alpha,
            v_head, Atype, head_size, w * head_size,
            prob_head, Btype, sequence_length, sequence_length * w + w,
            beta_0,
            out_head, Ctype, head_size, w * head_size,
            count,
            resultType, algo));
      }
    }

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_none, rocblas_operation_none,
        head_size, w, 2 * w,
        alpha,
        v, Atype, head_size, sequence_length * head_size,
        softmax_out, Btype, sequence_length, sequence_length * sequence_length,
        beta_0,
        output, Ctype, head_size, sequence_length * head_size,
        batch_size * num_heads,
        resultType, algo));

    const void* v_head = reinterpret_cast<const char*>(v) + (last_block - 1) * w * head_size * element_size;
    const void* prob_head = reinterpret_cast<const char*>(softmax_out) +
                            (sequence_length * last_block * w + (last_block - 1) * w) * element_size;
    void* out_head = reinterpret_cast<char*>(output) + last_block * w * head_size * element_size;

    CHECK(_compat_rocblas_gemm_strided_batched_ex(
        rocblas, rocblas_operation_none, rocblas_operation_none,
        head_size, w, 2 * w,
        alpha,
        v_head, Atype, head_size, sequence_length * head_size,
        prob_head, Btype, sequence_length, sequence_length * sequence_length,
        beta_0,
        out_head, Ctype, head_size, sequence_length * head_size,
        batch_size * num_heads,
        resultType, algo));
  }

  for (int i = 0; i < batch_size; ++i) {
    if (batch_global_count[i] > 0) {
      int glob_longdim_mm = (last_block - 1) * w;

      const void* v_head = reinterpret_cast<const char*>(v) + (i * x_offset) * element_size;
      const void* prob_head = reinterpret_cast<const char*>(softmax_out) +
                              (i * y_offset + 2 * w * sequence_length) * element_size;
      void* out_head = reinterpret_cast<char*>(output) + (i * x_offset + 2 * w * head_size) * element_size;

      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          head_size, glob_longdim_mm, batch_global_count[i],
          alpha,
          v_head, Atype, head_size, sequence_length * head_size,
          prob_head, Btype, sequence_length, sequence_length * sequence_length,
          beta_1,
          out_head, Ctype, head_size, sequence_length * head_size,
          num_heads,
          resultType, algo));

      // Global tokens
      v_head = reinterpret_cast<const char*>(global_v) + (i * x_offset) * element_size;
      prob_head = reinterpret_cast<const char*>(softmax_out) + (i * y_offset) * element_size;
      out_head = reinterpret_cast<char*>(output) + (i * x_offset) * element_size;

      CHECK(_compat_rocblas_gemm_strided_batched_ex(
          rocblas, rocblas_operation_none, rocblas_operation_none,
          head_size, batch_global_count[i],
          sequence_length,  // Re-write entries completely
          alpha,
          v_head, Atype, head_size, sequence_length * head_size,
          prob_head, Btype, sequence_length, sequence_length * sequence_length,
          beta_0,           // Use beta=0 to overwrite
          out_head,         // Here assumes global tokens are at the beginning of sequence.
          Ctype, head_size, sequence_length * head_size,
          num_heads,
          resultType, algo));
    }
  }
  return Status::OK();
}

}  // namespace rocm
}  // namespace contrib
}  // namespace onnxruntime
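To quantify the file-header note that this non-compact path uses two BxNxSxS scratch matrices, a small standalone sketch follows (illustrative only; the dimensions are assumptions and the 128-byte alignment applied by the real sizing helpers is ignored):

// Illustrative only: compare the non-compact softmax scratch (two BxNxSxS matrices)
// with the compact layout's (5S - 3W)*W*N*B elements.
#include <cstdio>

void CompareSoftmaxScratch() {
  const size_t B = 1, N = 12, S = 4096, W = 512, element_size = 2;     // fp16
  const size_t non_compact = 2 * B * N * S * S * element_size;         // scratch1 + scratch2
  const size_t compact = (5 * S - 3 * W) * W * N * B * element_size;   // see GetScratch1Size
  std::printf("non-compact: %zu bytes, compact: %zu bytes\n", non_compact, compact);
}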