gaoqiong / composable_kernel

Commit 763e26be, authored Aug 19, 2023 by letaoqin
Parent: de53e421

    kernel add all code, need debug
Showing 3 changed files with 117 additions and 26 deletions:

    example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp  (+11, -2)
    include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp  (+101, -24)
    include/ck/utility/static_buffer.hpp  (+5, -0)
example/52_flash_atten_bias/batched_multihead_attention_bias_backward_v2.cpp
@@ -199,6 +199,7 @@ using ReferenceDropoutInstance =
 template <typename TensorQ,
           typename TensorK,
           typename TensorV,
+          typename TensorD,
           typename TensorS,
           typename TensorP,
           typename TensorZ,
@@ -207,6 +208,7 @@ template <typename TensorQ,
 void run_attention_fwd_host(const TensorQ& q_g_m_k,
                             const TensorK& k_g_n_k,
                             const TensorV& v_g_n_o,
+                            const TensorD& d_g_m_n,
                             const float alpha,
                             TensorS& s_g_m_n,
                             TensorP& p_g_m_n,
@@ -226,6 +228,9 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
     ref_gemm0_invoker.Run(ref_gemm0_argument);
 
+    // bias
+    s_g_m_n.ForEach(
+        [&](auto& self, auto idx) { self(idx) += ck::type_convert<AccDataType>(d_g_m_n(idx)); });
 
     // masking
     auto M = s_g_m_n.GetLengths()[1];
     auto N = s_g_m_n.GetLengths()[2];
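The reference-path change above folds the attention bias into the scores before masking and softmax. As a minimal standalone sketch (plain C++ with hypothetical names, not the CK Tensor API), the ForEach amounts to an elementwise accumulate over two identically laid-out (G, M, N) tensors:

#include <cstddef>
#include <vector>

// s and d are flat (G * M * N) buffers with identical layout, so the
// per-index accumulate collapses to one linear loop.
void bias_add(std::vector<float>& s, const std::vector<float>& d)
{
    for(std::size_t i = 0; i < s.size(); ++i)
        s[i] += d[i]; // mirrors self(idx) += type_convert<AccDataType>(d_g_m_n(idx))
}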
@@ -261,7 +266,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
 int run(int argc, char* argv[])
 {
     bool do_verification = true;
-    int init_method      = 2; // method 1 will have slightly higher error; TODO: to investigate
+    int init_method      = 1; // method 1 will have slightly higher error; TODO: to investigate
     bool time_kernel     = true;
 
     // Overall QKV matrices shape
@@ -409,11 +414,13 @@ int run(int argc, char* argv[])
     {
     case 0: break;
     case 1:
         // q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
+        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{0});
         break;
     case 2:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
@@ -509,6 +516,7 @@ int run(int argc, char* argv[])
     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
+    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
     z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
@@ -611,7 +619,7 @@ int run(int argc, char* argv[])
         (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
          sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
          sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-         sizeof(OutputDataType) * N * O) *
+         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
             BatchCount +
         sizeof(LSEDataType) * M * BatchCount;
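The one-line change above extends the memory-traffic estimate to count the bias read: each batch now also moves sizeof(Acc0BiasDataType) * M * N bytes. A standalone sketch of the arithmetic (sizes and element widths below are hypothetical placeholders, not values from the example):

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical problem sizes; the example derives its own from arguments.
    std::size_t M = 512, N = 512, K = 64, O = 64, BatchCount = 16;
    std::size_t in = 2, out = 2, bias = 2, lse = 4; // assumed element sizes in bytes

    std::size_t num_btype = (in * M * K + in * K * N + in * N * O +
                             in * M * O * std::size_t(2) + // M*O counted twice (presumably Y and dY)
                             out * M * K + out * K * N + out * N * O +
                             bias * M * N) *               // bias read, added by this commit
                                BatchCount +
                            lse * M * BatchCount;

    std::printf("estimated bytes moved: %zu\n", num_btype);
    return 0;
}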
@@ -635,6 +643,7 @@ int run(int argc, char* argv[])
         run_attention_fwd_host(q_g_m_k,
                                k_g_n_k,
                                v_g_n_o,
+                               d_g_m_n,
                                alpha,
                                s_g_m_n,
                                p_g_m_n,
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
@@ -1181,7 +1181,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     struct D0Loader
     {
-        __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M()
+        __host__ __device__ static constexpr auto GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3()
         {
             // B1 matrix in LDS memory, dst of blockwise copy
             return make_naive_tensor_descriptor(
@@ -1193,8 +1193,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 D0M3, I1));
         }
 
+        __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3()
+        {
+            constexpr auto d0_raw_m0_n_m1 = make_naive_tensor_descriptor(
+                make_tuple(D0M2, Number<NPerBlock>{}, D0M3),
+                make_tuple(Number<NPerBlock>{} * D0M3, D0M3, I1));
+
+            constexpr auto d0_n0_n1_m0_m1_m2_m3 = transform_tensor_descriptor(
+                d0_raw_m0_n_m1,
+                make_tuple(
+                    make_unmerge_transform(make_tuple((D0M2 * D0M3) / I8, I2, I4 / D0M3)),
+                    make_unmerge_transform(
+                        make_tuple(Number<NPerBlock / NPerXdl>{}, Number<NPerXdl>{})),
+                    make_pass_through_transform(D0M3)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<2, 3, 4>{}, Sequence<0, 1>{}, Sequence<5>{}));
+
+            return d0_n0_n1_m0_m1_m2_m3;
+        }
+
         static constexpr auto d0_block_desc_m0_n0_m1_m2_n1_m3 =
-            GetD0BlockDescriptor_M0_N0_M1_M2_N1_M();
+            GetD0BlockDescriptor_M0_N0_M1_M2_N1_M3();
+
+        static constexpr auto d0_block_desc_n0_n1_m0_m1_m2_m3 =
+            GetD0BlockReadDescriptor_N0_N1_M0_M1_M2_M3();
+
+        static constexpr auto d0_thread_desc_ = make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I4, I1, I4 / D0M3, D0M3));
 
         using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
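The new read descriptor re-tiles the LDS view of the bias so each thread can pull its slice with the N dimension split across XDL lanes. In CK, make_unmerge_transform splits one dimension of length l0*l1*l2 into separate coordinates; a plain-C++ sketch of that index arithmetic (illustrative only, not the descriptor machinery):

#include <array>
#include <cstdio>

// Split a linear index over a dimension of length l0 * l1 * l2 into its
// three unmerged coordinates (row-major, as make_unmerge_transform does
// symbolically at compile time).
std::array<int, 3> unmerge(int idx, int l1, int l2)
{
    return {idx / (l1 * l2), (idx / l2) % l1, idx % l2};
}

int main()
{
    auto c = unmerge(13, 4, 2); // a 16-long dimension viewed as 2 x 4 x 2
    std::printf("%d %d %d\n", c[0], c[1], c[2]); // prints: 1 2 1
    return 0;
}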
@@ -1219,6 +1239,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                                          false,
                                          true, // DstResetCoord
                                          1>;
+
+        using D0ThreadCopy =
+            ThreadwiseTensorSliceTransfer_v4<D0DataType, // SrcData
+                                             D0DataType, // DstData
+                                             decltype(d0_block_desc_n0_n1_m0_m1_m2_m3), // SrcDesc
+                                             decltype(d0_thread_desc_),                 // DstDesc
+                                             Sequence<1, 1, 4, 1, 2, 2>, // SliceLengths
+                                             Sequence<0, 1, 2, 3, 4, 5>, // DimAccessOrder
+                                             5,                          // SrcVectorDim
+                                             D0M3.value,                 // SrcScalarPerVector
+                                             D0M3.value>;
     };
 
     template <bool HasMainKBlockLoop,
@@ -1546,14 +1577,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             qgrad_thread_origin_on_grid_m0_o0_m1_o1_m2_o2_o3_o4,
             scale_rp_dropout);
 
-        // D0
-        auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
-            d0_grid_desc_m0_n0_m1_m2_n1_m3,
-            make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
-            tensor_operation::element_wise::PassThrough{},
-            D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
-            make_multi_index(0, 0, 0, 0, 0, 0),
-            tensor_operation::element_wise::PassThrough{});
-
         //
         // Blockwise softmax
         //
@@ -1703,7 +1726,27 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 0,                //
                 wave_m_n_id[I1]), // NPerXdl
             tensor_operation::element_wise::PassThrough{}};
 
+        // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 46)
+        // {
+        //     printf("get_thread_local_1d_id(): %d, wave_id[I0]: %d wave_id[I1]: %d "
+        //            "wave_m_n_id[I0]: %d wave_m_n_id[I1]: %d \n",
+        //            get_thread_local_1d_id(),
+        //            wave_id[I0],
+        //            wave_id[I1],
+        //            wave_m_n_id[I0],
+        //            wave_m_n_id[I1]);
+        // }
+
+        // D0
+        auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+            d0_grid_desc_m0_n0_m1_m2_n1_m3,
+            make_multi_index(0, block_work_idx_n, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{},
+            D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
+            make_multi_index(0, 0, 0, 0, 0, 0),
+            tensor_operation::element_wise::PassThrough{});
+
+        auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadCopy(
+            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0, 0));
+
+        ignore = d0_thread_copy_lds_to_vgpr;
 
         //
         // set up Y dot dY
         //
@@ -1940,6 +1983,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
             // add bias
             if constexpr(!is_same<D0DataType, void>::value)
             {
+                static constexpr auto c_thread_desc_ =
+                    make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<16>{}));
+
                 const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
@@ -1947,15 +1993,47 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                     static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                     D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
 
-                static_for<0, D0M1, 1>{}([&](auto) {
+                auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
+                    D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                ignore = d0_thread_buf;
+
+                static_for<0, D0M1, 1>{}([&](auto mr) {
                     // load data to lds
                     d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                         d0_grid_buf);
                     d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                         d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                     d0_block_copy_global_to_lds.RunWrite(D0Loader::d0_block_desc_m0_n0_m1_m2_n1_m3,
                                                          d0_block_buf);
+                    block_sync_lds();
+
+                    // read data from lds
+                    d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_desc_n0_n1_m0_m1_m2_m3,
+                                                   make_tuple(I0, I0, I0, I0, I0, I0),
+                                                   d0_block_buf,
+                                                   D0Loader::d0_thread_desc_,
+                                                   make_tuple(I0, I0, I0, I0, I0, I0),
+                                                   d0_thread_buf);
+
+                    // bias add
+                    static_for<0, D0Loader::d0_thread_desc_.GetElementSpaceSize(), 1>{}(
+                        [&](auto i) {
+                            constexpr index_t c_offset =
+                                c_thread_desc_.CalculateOffset(make_tuple(mr, i));
+                            s_slash_p_thread_buf(Number<c_offset>{}) +=
+                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
+                            // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
+                            // {
+                            //     printf("c_offset: %d s_slash_p_thread_buf(Number<c_offset>{}): "
+                            //            "%f, d0_thread_buf[i]: %f\n",
+                            //            c_offset,
+                            //            s_slash_p_thread_buf(Number<c_offset>{}),
+                            //            ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]));
+                            // }
+                        });
                 });
 
                 d0_block_copy_global_to_lds.MoveSrcSliceWindow(
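The loop added above stages the bias tile in two hops: a blockwise copy global -> LDS (RunRead/RunWrite around a slice-window move, followed by a barrier), then a threadwise copy LDS -> VGPR before the accumulate into s_slash_p_thread_buf. A host-side toy model of that data flow (illustrative only; sizes and names below are made up, not CK API):

#include <cstdio>
#include <vector>

int main()
{
    const int Tile = 4, Tiles = 3;                 // stand-ins for the D0M1-tiled extents
    std::vector<float> global(Tile * Tiles, 1.0f); // bias in "global memory"
    std::vector<float> lds(Tile);                  // stand-in for the LDS staging buffer
    std::vector<float> scores(Tile * Tiles, 0.0f); // stand-in for s_slash_p_thread_buf

    for(int mr = 0; mr < Tiles; ++mr) // outer static_for over the bias tiles
    {
        for(int i = 0; i < Tile; ++i) // RunRead + RunWrite: global -> LDS
            lds[i] = global[mr * Tile + i];
        // block_sync_lds() would sit here on the GPU
        for(int i = 0; i < Tile; ++i) // thread copy + bias add: LDS -> registers -> scores
            scores[mr * Tile + i] += lds[i];
    }
    std::printf("scores[0] = %f\n", scores[0]);
    return 0;
}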
@@ -2036,11 +2114,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // dV = P_drop^T * dY
         {
             // TODO: explore using dynamic buffer for a1 thread buffer
             // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
             // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
             // the A1 source buffer is static buffer holding the output of first GEMM and
             // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
             // explicitly in Run() below.
 
             // preload data into LDS
             vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
@@ -2196,11 +2274,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         // dK = scalar * dS^T * Q
         {
             // TODO: explore using dynamic buffer for a1 thread buffer
             // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
             // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
             // the A1 source buffer is static buffer holding the output of first GEMM and
             // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
             // explicitly in Run() below.
 
             // preload data into LDS
             kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
@@ -2286,7 +2364,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                     make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
             }
-
         } while(0 < gemm0_m_block_outer_index--); // end j loop
 
         // shuffle dK&dV and write
include/ck/utility/static_buffer.hpp
@@ -111,6 +111,11 @@ struct StaticBufferTupleOfVector
         return base::operator()(i_v).template AsType<S>()(i_s);
     }
 
+    template <index_t I>
+    __host__ __device__ constexpr S& operator()(Number<I> i_v, Number<I> i_s)
+    {
+        return base::operator()(i_v).template AsType<S>()(i_s);
+    }
 
     // Get X
     // i is offset of S, not X. i should be aligned to X
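The five added lines give StaticBufferTupleOfVector a two-index operator() that returns a mutable S& alongside the existing accessor; note that both indices share the same template parameter I here, which may be deliberate or part of the commit's stated "need debug" state. As a plain-C++ analogy for pairing a read accessor with a write accessor (hypothetical types, not the CK buffer):

#include <array>

struct Buf
{
    std::array<float, 8> data{};

    // read path: const access
    constexpr const float& at(int i) const { return data[i]; }

    // write path: same body, mutable reference
    constexpr float& at(int i) { return data[i]; }
};

int main()
{
    Buf b;
    b.at(3) = 2.5f; // resolves to the non-const overload
    return b.at(3) == 2.5f ? 0 : 1;
}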