Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
289e1196
"vscode:/vscode.git/clone" did not exist on "3848606c7ed98c585b7a41397f99e1a873b17f61"
Commit
289e1196
authored
Aug 23, 2023
by
letaoqin
Browse files
multiple M block
parent
a33c100d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
42 deletions
+45
-42
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
...ten_bias/batched_multihead_attention_bias_backward_v2.cpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+42
-39
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
View file @
289e1196
...
...
@@ -273,8 +273,8 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
M
=
64
;
ck
::
index_t
N
=
128
;
ck
::
index_t
M
=
128
;
ck
::
index_t
N
=
512
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
1
;
...
...
@@ -468,7 +468,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
k_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
v_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
InputDataType
>
{});
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
//dy[g0,g1, m, o]
ygrad_gs_ms_os
.
GenerateTensorValue
(
GeneratorTensor_1
<
InputDataType
>
{
1
});
// dy[g0,g1, m, o]
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_1
<
Acc0BiasDataType
>
{
1
});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
289e1196
...
...
@@ -122,11 +122,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
constexpr
auto
B1K0
=
Number
<
Gemm1KPerBlock
/
B1K1Value
>
{};
static
constexpr
auto
B1K1
=
Number
<
B1K1Value
>
{};
// D0
static
constexpr
auto
D0M1
=
Number
<
4
>
{};
static
constexpr
auto
D0M0
=
Number
<
MPerBlock
/
D0M1
.
value
>
{};
// static constexpr auto D0M1 = Number<MPerBlock / MPerXdl>{};
static
constexpr
auto
mfma
=
MfmaSelector
<
GemmDataType
,
MPerXdl
,
NPerXdl
>::
selected_mfma
;
static
constexpr
auto
DropoutNThread
=
mfma
.
num_input_blks
;
// 2
// get_random_8x16() generates 8 random numbers each time
...
...
@@ -1157,21 +1152,25 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_block_bytes_end
);
}
// D0
static
constexpr
auto
D0M1
=
Number
<
4
>
{};
static
constexpr
auto
D0M0
=
Number
<
MPerBlock
>
{}
/
D0M1
;
__host__
__device__
static
constexpr
auto
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
const
D0GridDesc_M_N
&
d0_grid_desc_m_n
)
{
// const auto M = d0_grid_desc_m_n.GetLength(I0);
// const auto N = d0_grid_desc_m_n.GetLength(I1);
const
auto
M
=
d0_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
d0_grid_desc_m_n
.
GetLength
(
I1
);
// const auto MBlock = M / MPerBlock;
// const auto NBlock = N / NPerBlock;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
transform_tensor_descriptor
(
d0_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
D0M0
,
D0M1
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NPerBlock
>
{}))),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
D0M0
,
D0M1
)),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
>
{}));
return
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
}
...
...
@@ -1184,8 +1183,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__host__
__device__
static
constexpr
auto
GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
D0M0
,
Number
<
NPerBlock
>
{},
D0M1
),
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M1
,
D0M1
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
I1
,
I1
,
D0M0
,
Number
<
NPerBlock
>
{},
D0M1
),
make_tuple
(
Number
<
NPerBlock
>
{}
*
D0M1
,
Number
<
NPerBlock
>
{}
*
D0M1
,
Number
<
NPerBlock
>
{}
*
D0M1
,
D0M1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3
()
{
...
...
@@ -1215,17 +1218,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
D0M0
,
NPerBlock
,
D0M1
>
,
// BlockSliceLengths
Sequence
<
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
2
,
1
>
,
// ThreadClusterArrangeOrder
Sequence
<
I1
,
I1
,
D0M0
,
NPerBlock
,
D0M1
>
,
// BlockSliceLengths
Sequence
<
1
,
1
,
8
,
32
,
1
>
,
// ThreadClusterLengths
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// ThreadClusterArrangeOrder
D0DataType
,
// SrcData
D0DataType
,
// DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3
,
// SrcDesc
decltype
(
d0_block_desc_m0_n0_m1_m2_n1_m3
),
// DstDesc
Sequence
<
0
,
2
,
1
>
,
// SrcDimAccessOrder
Sequence
<
1
,
0
,
2
>
,
// DstDimAccessOrder
1
,
// SrcVectorDim
2
,
// DstVectorDim
Sequence
<
0
,
1
,
2
,
4
,
3
>
,
// SrcDimAccessOrder
Sequence
<
0
,
1
,
3
,
2
,
4
>
,
// DstDimAccessOrder
3
,
// SrcVectorDim
4
,
// DstVectorDim
4
,
// SrcScalarPerVector
4
,
// DstScalarPerVector
1
,
...
...
@@ -1242,8 +1245,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
Sequence
<
1
,
1
,
8
,
1
,
4
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
// DimAccessOrder
4
,
// SrcVectorDim
4
,
// SrcScalarPerVector
4
>
;
2
,
// SrcScalarPerVector
2
>
;
};
template
<
bool
HasMainKBlockLoop
,
...
...
@@ -1730,17 +1733,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// wave_m_n_id[I0],
// wave_m_n_id[I1]);
// }
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
ignore
=
d0_thread_copy_lds_to_vgpr
;
//
// set up Y dot dY
//
...
...
@@ -1833,6 +1825,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// gemm0 M loop
index_t
gemm0_m_block_outer_index
=
num_gemm0_m_block_outer_loop
-
1
;
// D0
auto
d0_block_copy_global_to_lds
=
typename
D0Loader
::
D0BlockwiseCopy
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
gemm0_m_block_outer_index
,
block_work_idx_n
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{},
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
auto
d0_thread_copy_lds_to_vgpr
=
typename
D0Loader
::
D0ThreadCopy
(
make_tuple
(
wave_id
[
I1
],
wave_m_n_id
[
I1
],
0
,
wave_m_n_id
[
I0
],
0
));
do
{
auto
m_block_data_idx_on_grid
=
...
...
@@ -2011,7 +2015,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, 0, 0, 0));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
...
...
@@ -2029,7 +2033,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(mr, i));
//if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
// if(ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]) != 1.0f)
// {
// float tmp_lds =
...
...
@@ -2049,9 +2053,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
});
//});
// d0_block_copy_global_to_lds.MoveSrcSliceWindow(
// d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(1, 0, -D0M1.value, 0, 0,
// 0));
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
-
1
,
0
,
0
,
0
,
0
));
}
// P_i: = softmax(scalar * S_i:)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment