gaoqiong / composable_kernel

Commit 98212afe, authored Aug 12, 2023 by letaoqin
Parent: 102c9661

    batched gemm reduce interface

Showing 5 changed files with 233 additions and 304 deletions (+233 -304):
  example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp  (+2 -2)
  example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc  (+13 -16)
  include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp  (+8 -8)
  include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp  (+158 -192)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp  (+52 -86)
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp

@@ -53,8 +53,8 @@ using CDataType = DataType;
 using DDataType   = F16;
 using ZDataType   = U16; // INT32
 using LSEDataType = F32;
-using Acc0BiasDataType = ck::Tuple<DDataType>;
-using Acc1BiasDataType = ck::Tuple<>;
+using Acc0BiasDataType = DDataType;
+using Acc1BiasDataType = void;
 static constexpr ck::index_t NumDimG = 2;
 static constexpr ck::index_t NumDimM = 1;
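The change above replaces the tuple-wrapped bias aliases with a single element type, with `void` standing in for "no bias on that gemm". A minimal standalone sketch of that convention (the names and function below are illustrative, not from the repository):

    #include <iostream>
    #include <type_traits>

    // Hypothetical stand-ins for the example's aliases: a concrete type means
    // "bias present", void means "no bias" (as in Acc1BiasDataType = void above).
    using Acc0Bias = float; // bias tensor element type
    using Acc1Bias = void;  // no second bias

    template <typename BiasType>
    void describe_bias(const char* name)
    {
        // Dispatch at compile time on the void-vs-concrete convention.
        if constexpr(std::is_void<BiasType>::value)
            std::cout << name << ": no bias tensor\n";
        else
            std::cout << name << ": bias of " << sizeof(BiasType) << "-byte elements\n";
    }

    int main()
    {
        describe_bias<Acc0Bias>("acc0");
        describe_bias<Acc1Bias>("acc1");
    }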
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc

@@ -190,8 +190,8 @@ int run(int argc, char* argv[])
             static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
             static_cast<ZDataType*>(nullptr),
             static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-            std::array<void*, 1>{d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
-            {},                                                   // std::array<void*, 1> p_acc1_biases;
+            static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
+            nullptr,
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b0_gs_ns_ks_lengths,
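At this call site the bias arguments shed their one-element `std::array<void*, 1>` wrappers: a typed pointer is passed directly, and `nullptr` now signals an absent bias. A small sketch of that pointer-or-null convention (the builder function here is hypothetical, not the CK API):

    #include <cstdio>

    // Hypothetical kernel-argument builder: a null bias pointer means the
    // bias add is skipped, mirroring the nullptr passed for p_acc1_biases.
    void make_argument(const void* p_acc0_bias, const void* p_acc1_bias)
    {
        std::printf("acc0 bias: %s\n", p_acc0_bias ? "present" : "absent");
        std::printf("acc1 bias: %s\n", p_acc1_bias ? "present" : "absent");
    }

    int main()
    {
        float d_buf[16] = {};          // stands in for d_device_buf
        make_argument(d_buf, nullptr); // bias on gemm0 only, as in the diff
    }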
@@ -203,10 +203,10 @@ int run(int argc, char* argv[])
             z_gs_ms_ns_lengths,
             z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
-            std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
-            std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
+            d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
+            d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
+            {}, // std::vector<ck::index_t>
+            {}, // std::vector<ck::index_t>
             a_element_op,
             b0_element_op,
             acc0_element_op,
@@ -230,7 +230,7 @@ int run(int argc, char* argv[])
     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-    std::size_t num_btype =
-        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-         sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
-         sizeof(DDataType) * M * N * Acc0BiasDataType::Size()) * BatchCount;
+    std::size_t num_btype =
+        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
+         sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
+         sizeof(DDataType) * M * N * std::is_void<DDataType>::value ? 1 : 0) * BatchCount;
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
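The benchmark accounting above is the usual fused-attention arithmetic: the two gemms contribute 2·M·N·K and 2·M·N·O flops per batch, and the byte count sums the tensors actually touched. A self-contained rendering with illustrative sizes (not the example's defaults), assuming 2-byte (fp16) elements throughout:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Illustrative problem sizes for one attention head.
        const std::size_t M = 512, N = 512, K = 64, O = 64, BatchCount = 16;

        // Two gemms: (M x K) * (K x N) and (M x N) * (N x O), 2 flops per MAC.
        const std::size_t flop = (M * N * K * 2 + M * N * O * 2) * BatchCount;

        // Bytes moved for fp16 A, B0, B1, C plus the fp16 bias D (M x N), per batch.
        const std::size_t num_btype =
            (2 * M * K + 2 * K * N + 2 * N * O + 2 * M * O + 2 * M * N) * BatchCount;

        std::printf("flop = %zu, bytes = %zu\n", flop, num_btype);
    }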
@@ -250,9 +250,8 @@ int run(int argc, char* argv[])
             static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
             static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
             static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-            std::array<void*, 1>{
-                d_device_buf.GetDeviceBuffer()}, // std::array<void*, 1> p_acc0_biases;
-            {},                                  // std::array<void*, 1> p_acc1_biases;
+            static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
+            nullptr,
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b0_gs_ns_ks_lengths,
@@ -264,12 +263,10 @@ int run(int argc, char* argv[])
             z_gs_ms_ns_lengths,
             z_gs_ms_ns_strides,
             lse_gs_ms_lengths,
-            std::array<std::vector<ck::index_t>, 1>{
-                d_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
-            std::array<std::vector<ck::index_t>, 1>{
-                d_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
-            {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
+            d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
+            d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
+            {},
+            {},
             a_element_op,
             b0_element_op,
             acc0_element_op,
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp

@@ -87,8 +87,8 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec>
 struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 {
-    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
-    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
+    static constexpr index_t NumAcc0Bias = 1;
+    static constexpr index_t NumAcc1Bias = 0;
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const void* p_a,
@@ -97,8 +97,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
         void* p_c,
         void* p_z,
         void* p_lse,
-        const std::array<void*, NumAcc0Bias> p_acc0_biases,
-        const std::array<void*, NumAcc1Bias> p_acc1_biases,
+        const void* p_acc0_biases,
+        const void* p_acc1_biases,
         const std::vector<index_t>& a_gs_ms_ks_lengths,
         const std::vector<index_t>& a_gs_ms_ks_strides,
         const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -110,11 +110,11 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
         const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_os_lengths
         const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_os_strides
         const std::vector<index_t>& lse_gs_ms_lengths,  // lse_gs_ms_lengths
-        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
-        const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
-        const std::array<std::vector<index_t>, NumAcc1Bias>
-            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
-        const std::array<std::vector<index_t>, NumAcc1Bias>
-            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
+        const std::vector<index_t>& acc0_biases_gs_ms_ns_lengths,
+        const std::vector<index_t>& acc0_biases_gs_ms_ns_strides,
+        const std::vector<index_t>&
+            acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
+        const std::vector<index_t>&
+            acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
         AElementwiseOperation a_element_op,
         B0ElementwiseOperation b0_element_op,
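Taken together, the interface now accepts at most one bias per gemm as a plain pointer plus a single lengths/strides vector, rather than `std::array` collections sized by a tuple's `Size()`. A miniature sketch of such a narrowed virtual interface (types and names invented for illustration; `index_t` is a stand-in for `ck::index_t`):

    #include <cstdint>
    #include <memory>
    #include <vector>

    using index_t = std::int32_t; // stand-in for ck::index_t

    // Hypothetical miniature of the narrowed interface: one optional bias,
    // passed as a raw pointer plus single lengths/strides vectors, instead of
    // std::array<...> collections sized by NumAcc0Bias/NumAcc1Bias.
    struct BaseOp
    {
        virtual ~BaseOp() = default;
        virtual void MakeArgument(const void* p_acc0_bias, // may be nullptr
                                  const std::vector<index_t>& d_lengths,
                                  const std::vector<index_t>& d_strides) = 0;
    };

    struct MyOp final : BaseOp
    {
        void MakeArgument(const void* p_acc0_bias,
                          const std::vector<index_t>&,
                          const std::vector<index_t>&) override
        {
            (void)p_acc0_bias; // a real operator would wire this into the kernel
        }
    };

    int main()
    {
        std::unique_ptr<BaseOp> op = std::make_unique<MyOp>();
        op->MakeArgument(nullptr, {1, 2, 512, 512}, {0, 0, 512, 1});
    }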
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp

(This diff is collapsed in the original view: +158 -192.)
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle_v2.hpp

@@ -25,7 +25,7 @@ namespace ck {
 *
 */
 template <typename FloatAB,
-          typename D0sDataType,
+          typename D0DataType,
           typename ZDataType,
           typename FloatGemm,
           typename FloatGemmAcc,
@@ -40,7 +40,7 @@ template <typename FloatAB,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
-          typename D0sGridDesc_M_N,
+          typename D0GridDesc_M_N,
           typename B1GridDesc_BK0_N_BK1,
           typename CGridDesc_M_N,
           typename ZGridDesc_M_N,
@@ -102,7 +102,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                       D0BlockTransferSrcScalarPerVector == 2 ||
                       D0BlockTransferSrcScalarPerVector == 4,
                   "D0BlockTransferSrcScalarPerVector must be 1 or 2 or 4");
-    static constexpr index_t NumD0Tensor = D0sDataType::Size();
     static_assert(LoopSched == LoopScheduler::Default,
                   "Non-default loop scheduler is currently not supported");
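The surviving static_assert restricts D0BlockTransferSrcScalarPerVector to 1, 2, or 4. A minimal sketch of this compile-time guard pattern (VecWidth is an invented stand-in for the CK parameter):

    #include <cstdio>

    // Reject unsupported vector widths at compile time, as the kernel above
    // does for its D0 source-scalar-per-vector parameter.
    template <int VecWidth>
    struct BlockTransfer
    {
        static_assert(VecWidth == 1 || VecWidth == 2 || VecWidth == 4,
                      "VecWidth must be 1 or 2 or 4");
        static constexpr int value = VecWidth;
    };

    int main()
    {
        std::printf("%d\n", BlockTransfer<4>::value); // BlockTransfer<3> would not compile
    }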
@@ -441,20 +440,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 c_grid_desc_m_n);
     }
-    static constexpr auto MakeD0sGridPointer()
-    {
-        return generate_tuple(
-            [&](auto i) {
-                using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
-                return static_cast<const D0DataType*>(nullptr);
-            },
-            Number<NumD0Tensor>{});
-    }
     // D0 desc for source in blockwise copy
-    template <typename D0GridDesc_M_N>
     __host__ __device__ static constexpr auto
-    MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
+    MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
     {
         const auto M = d0_grid_desc_m_n.GetLength(I0);
         const auto N = d0_grid_desc_m_n.GetLength(I1);
@@ -472,20 +460,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
     }
-    // D0s desc for source in blockwise copy
-    __host__ __device__ static constexpr auto
-    MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
-    {
-        return generate_tuple(
-            [&](auto i) {
-                return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
-                    ds_grid_desc_m_n[i]);
-            },
-            Number<NumD0Tensor>{});
-    }
-    using D0sGridPointer = decltype(MakeD0sGridPointer());
-    using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
-        MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
+    using D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
+        MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0GridDesc_M_N{}))>;
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
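Both the removed and the added aliases rely on the same idiom: recover a factory's return type with `decltype` over a value-constructed descriptor. A self-contained sketch of that pattern (all types here are invented, not CK's):

    #include <tuple>
    #include <type_traits>

    struct GridDescMN { int m; int n; }; // toy descriptor

    template <typename Desc>
    struct Gridwise
    {
        // Toy "transform" standing in for MakeD0GridDescriptor_...
        static constexpr auto MakeBlockedDescriptor(const Desc& d)
        {
            return std::make_tuple(d.m / 64, 64, d.n / 64, 64);
        }

        // Same pattern as the using-declaration above: decltype over a
        // value-constructed Desc{} names the factory's return type.
        using BlockedDescriptor = std::remove_cv_t<
            std::remove_reference_t<decltype(MakeBlockedDescriptor(Desc{}))>>;
    };

    int main()
    {
        Gridwise<GridDescMN>::BlockedDescriptor b = std::make_tuple(8, 64, 8, 64);
        (void)b;
    }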
@@ -544,7 +520,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
               typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
-                               D0sGridPointer p_d0s_grid,
+                               const D0DataType* __restrict__ p_d0_grid,
                                const FloatAB* __restrict__ p_b1_grid,
                                FloatC* __restrict__ p_c_grid,
                                ZDataType* __restrict__ p_z_grid,
@@ -557,8 +533,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                const CElementwiseOperation& c_element_op,
                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
-                               const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
-                                   d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                               const D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
+                                   d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
@@ -985,49 +961,39 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                                                n3,   // NInputNum
                                                n4)); // RegisterNum
-        auto d0s_threadwise_copy = generate_tuple(
-            [&](auto i) {
-                using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
-                return ThreadwiseTensorSliceTransfer_v2<
-                    D0DataType,
-                    D0DataType,
-                    decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
-                    decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
-                    Sequence<I1, // MBlockId
-                             I1, // NBlockID
-                             m0, // MRepeat
-                             n0, // NRepeat
-                             m1, // MWaveId
-                             n1, // NWaveId
-                             m2, // MPerXdl
-                             n2, // NGroupNum
-                             n3, // NInputNum
-                             n4>,
-                    Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
-                    9,
-                    D0BlockTransferSrcScalarPerVector,
-                    1,
-                    false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                           make_multi_index(block_work_idx_m, // MBlockId
-                                            0,                // NBlockId
-                                            0,                // mrepeat
-                                            0,                // nrepeat
-                                            wave_id[I0],      // MWaveId
-                                            wave_id[I1],      // NWaveId
-                                            wave_m_n_id[I1],  // MPerXdl
-                                            0,                // group
-                                            wave_m_n_id[I0],  // NInputIndex
-                                            0));              // register number
-            },
-            Number<NumD0Tensor>{});
-        const auto d0s_grid_buf = generate_tuple(
-            [&](auto i) {
-                return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_d0s_grid[i],
-                    d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
-            },
-            Number<NumD0Tensor>{});
+        auto d0_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<D0DataType,
+                                             D0DataType,
+                                             decltype(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+                                             decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
+                                             Sequence<I1, // MBlockId
+                                                      I1, // NBlockID
+                                                      m0, // MRepeat
+                                                      n0, // NRepeat
+                                                      m1, // MWaveId
+                                                      n1, // NWaveId
+                                                      m2, // MPerXdl
+                                                      n2, // NGroupNum
+                                                      n3, // NInputNum
+                                                      n4>,
+                                             Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
+                                             9,
+                                             D0BlockTransferSrcScalarPerVector,
+                                             1,
+                                             false>(
+                d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                make_multi_index(block_work_idx_m, // MBlockId
+                                 0,                // NBlockId
+                                 0,                // mrepeat
+                                 0,                // nrepeat
+                                 wave_id[I0],      // MWaveId
+                                 wave_id[I1],      // NWaveId
+                                 wave_m_n_id[I1],  // MPerXdl
+                                 0,                // group
+                                 wave_m_n_id[I0],  // NInputIndex
+                                 0));              // register number
+        const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_d0_grid, d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
         constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = // for blockwise copy
             make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
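The rewrite collapses a `generate_tuple` of per-tensor copy objects into one `ThreadwiseTensorSliceTransfer_v2` instance. A toy illustration of the tuple-of-workers versus single-worker shapes, with a hand-rolled generate_tuple analogue (an invented helper, not the CK function):

    #include <cstdio>
    #include <tuple>
    #include <utility>

    // Toy stand-in for one "threadwise copy" worker.
    struct Copier
    {
        int id;
        void Run() const { std::printf("copy %d\n", id); }
    };

    // Rough analogue of CK's generate_tuple: build a tuple of workers from a
    // compile-time index sequence.
    template <typename F, std::size_t... Is>
    auto generate_tuple(F&& f, std::index_sequence<Is...>)
    {
        return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
    }

    int main()
    {
        // Old shape of the code: a tuple of per-tensor copy objects...
        auto copiers = generate_tuple([](auto i) { return Copier{(int)i}; },
                                      std::make_index_sequence<3>{});
        std::apply([](const auto&... c) { (c.Run(), ...); }, copiers);

        // ...new shape: exactly one copy object for the single D0 tensor.
        Copier single{0};
        single.Run();
    }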
@@ -1325,9 +1291,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                     block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
                     // add bias
-                    static_for<0, NumD0Tensor, 1>{}([&](auto i) {
+                    if constexpr(!std::is_void<D0DataType>::value)
+                    {
                         // get register
-                        using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
                         StaticBuffer<AddressSpaceEnum::Vgpr,
                                      D0DataType,
                                      d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
@@ -1335,20 +1301,20 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                             d0_thread_buf;
                         // load data from global
-                        d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
-                                                   d0s_grid_buf[i],
+                        d0_threadwise_copy.Run(d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                                               d0_grid_buf,
                                                d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                                                make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                                d0_thread_buf);
                         // acc add bias
                         static_for<0, m0 * n0 * n2 * n4, 1>{}(
-                            [&](auto j) { acc_thread_buf(j) += d0_thread_buf[j]; });
-                        d0s_threadwise_copy(i).MoveSrcSliceWindow(
-                            d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
+                            [&](auto i) { acc_thread_buf(i) += d0_thread_buf[i]; });
+                        d0_threadwise_copy.MoveSrcSliceWindow(
+                            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                             make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
-                    });
+                    }
                     // softmax
                     SoftmaxBuf& max = blockwise_softmax.max_value_buf;
                     SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;