Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
67e10a6a
"vscode:/vscode.git/clone" did not exist on "445de55044a693002f4cc0152ad3a57aa209d742"
Commit
67e10a6a
authored
Aug 28, 2023
by
letaoqin
Browse files
v1 group finished
parent
8efd67d8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
20 deletions
+77
-20
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
...ten_bias/grouped_multihead_attention_bias_backward_v2.cpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+75
-17
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+1
-2
No files found.
example/52_flash_atten_bias/grouped_multihead_attention_bias_backward_v2.cpp
View file @
67e10a6a
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
64
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
67e10a6a
...
...
@@ -134,6 +134,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -171,6 +172,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
d0_grid_desc_m0_n0_m1_m2_n1_m3_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
...
...
@@ -297,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std
::
vector
<
index_t
>
lse_gs_ms_lengths
;
std
::
vector
<
index_t
>
lse_gs_ms_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
index_t
>
acc0_biases_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
acc0_biases_gs_ms_ns_strides
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_biases_gs_ms_os_lengths
;
std
::
vector
<
std
::
vector
<
index_t
>
>
acc1_biases_gs_ms_os_strides
;
std
::
vector
<
index_t
>
acc1_biases_gs_ms_os_lengths
;
std
::
vector
<
index_t
>
acc1_biases_gs_ms_os_strides
;
};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -495,10 +497,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
}
// D in Gemm0 C position
static
auto
MakeDGridDescriptor_M_N
(
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths_vec
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides_vec
)
static
auto
MakeD0GridDescriptor_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_M_N
(
d_gs_ms_ns_lengths_vec
,
d_gs_ms_ns_strides_vec
);
return
Transform
::
MakeCGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
}
static
auto
MakeD0GridDescriptor_G_M_N
(
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
acc0_biases_gs_ms_ns_strides
)
{
return
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
,
acc0_biases_gs_ms_ns_strides
);
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
({},
{}));
...
...
@@ -508,12 +522,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using
LSEGridDesc_M
=
decltype
(
MakeLSEGridDescriptor_M
(
1
));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
BGridDesc_G_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
D0GridDesc_G_M_N
=
decltype
(
MakeD0GridDescriptor_G_M_N
({},
{}));
using
B1GridDesc_G_N_K
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
ZGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
KGridDesc_N_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_N_K
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeDGridDescriptor_M_N
({},
{}));
using
D0GridDesc_M_N
=
decltype
(
MakeD
0
GridDescriptor_M_N
({},
{}));
using
YGradGridDesc_O0_M_O1
=
decltype
(
MakeYGradGridDescriptor_O0_M_O1
({},
{}));
using
ZGridDesc_M_N
=
decltype
(
MakeZGridDescriptor_M_N
({},
{}));
...
...
@@ -538,12 +553,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
ComputeBasePtrOfStridedBatch
(
const
AGridDesc_G_M_K
&
a_grid_desc_g_m_k
,
const
BGridDesc_G_N_K
&
b_grid_desc_g_n_k
,
const
D0GridDesc_G_M_N
&
d0_grid_desc_g_m_n
,
const
ZGridDesc_G_M_N
&
z_grid_desc_g_m_n
,
const
B1GridDesc_G_N_K
&
b1_grid_desc_g_n_k
,
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
,
index_t
batch_stride_lse
)
:
a_grid_desc_g_m_k_
(
a_grid_desc_g_m_k
),
b_grid_desc_g_n_k_
(
b_grid_desc_g_n_k
),
d0_grid_desc_g_m_n_
(
d0_grid_desc_g_m_n
),
z_grid_desc_g_m_n_
(
z_grid_desc_g_m_n
),
b1_grid_desc_g_n_k_
(
b1_grid_desc_g_n_k
),
c_grid_desc_g_m_n_
(
c_grid_desc_g_m_n
),
...
...
@@ -561,6 +578,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return
b_grid_desc_g_n_k_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetD0BasePtr
(
index_t
g_idx
)
const
{
return
d0_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
}
__host__
__device__
constexpr
long_index_t
GetZBasePtr
(
index_t
g_idx
)
const
{
return
z_grid_desc_g_m_n_
.
CalculateOffset
(
make_multi_index
(
g_idx
,
0
,
0
));
...
...
@@ -584,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
private:
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
D0GridDesc_G_M_N
d0_grid_desc_g_m_n_
;
ZGridDesc_G_M_N
z_grid_desc_g_m_n_
;
B1GridDesc_G_N_K
b1_grid_desc_g_n_k_
;
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
...
...
@@ -663,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// pointers
const
InputDataType
*
p_a_grid_
;
const
InputDataType
*
p_b_grid_
;
const
D0DataType
*
p_d0_grid_
;
ZDataType
*
p_z_grid_
;
const
InputDataType
*
p_b1_grid_
;
const
InputDataType
*
p_c_grid_
;
...
...
@@ -675,6 +699,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -714,6 +739,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CGridDesc_G_M_N
c_grid_desc_g_m_n_
;
index_t
batch_count_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
};
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -759,16 +787,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Qgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Kgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_Vgrads
.
size
())
&&
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())))
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_LSEs
.
size
())
&&
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_biases
.
size
())
||
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_biases
.
size
()
==
0
))
&&
0
==
p_acc1_biases
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! group_count_ != p_As/b/b1/c.size"
);
}
if
(
!
(
p_acc0_biases
.
size
()
==
p_acc1_biases
.
size
()))
{
throw
std
::
runtime_error
(
"wrong! acc0_bias_vec.size != acc1_bias_vec.size"
);
}
grid_size_
=
0
;
index_t
z_random_matrix_offset
=
0
;
...
...
@@ -777,6 +803,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
const
auto
p_a_grid
=
static_cast
<
const
InputDataType
*>
(
p_As
[
i
]);
const
auto
p_b_grid
=
static_cast
<
const
InputDataType
*>
(
p_Bs
[
i
]);
const
auto
p_d0_grid
=
(
ck
::
type_convert
<
ck
::
index_t
>
(
p_acc0_biases
.
size
())
==
group_count_
)
?
static_cast
<
const
D0DataType
*>
(
p_acc0_biases
[
i
])
:
nullptr
;
auto
p_z_grid
=
static_cast
<
ZDataType
*>
(
p_Zs
[
i
]);
const
auto
p_b1_grid
=
static_cast
<
const
InputDataType
*>
(
p_B1s
[
i
]);
const
auto
p_c_grid
=
static_cast
<
const
InputDataType
*>
(
p_Cs
[
i
]);
...
...
@@ -792,6 +822,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_biases_gs_ms_ns_lengths
;
tmp_d0_gs_ms_ns_strides
=
problem_desc
.
acc0_biases_gs_ms_ns_strides
;
}
else
{
tmp_d0_gs_ms_ns_lengths
=
{
1
,
1
,
1
,
1
};
tmp_d0_gs_ms_ns_strides
=
{
0
,
0
,
0
,
0
};
}
const
D0GridDesc_M_N
d0_grid_desc_m_n
{
DeviceOp
::
MakeD0GridDescriptor_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
)};
const
auto
d0_grid_desc_m0_n0_m1_m2_n1_m3
=
GridwiseGemm
::
MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3
(
d0_grid_desc_m_n
);
const
auto
z_grid_desc_m_n
=
DeviceOp
::
MakeZGridDescriptor_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeVGridDescriptor_O0_N_O1
(
...
...
@@ -811,6 +858,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_g_n_k
=
Transform
::
MakeB0GridDescriptor_G_N_K
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
const
auto
d0_grid_desc_g_m_n
=
DeviceOp
::
MakeD0GridDescriptor_G_M_N
(
tmp_d0_gs_ms_ns_lengths
,
tmp_d0_gs_ms_ns_strides
);
const
auto
z_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
z_gs_ms_ns_lengths
,
problem_desc
.
z_gs_ms_ns_strides
);
const
auto
b1_grid_desc_g_n_k
=
Transform
::
MakeB1GridDescriptor_G_N_K
(
...
...
@@ -847,6 +896,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const
auto
compute_base_ptr_of_batch
=
ComputeBasePtrOfStridedBatch
(
a_grid_desc_g_m_k
,
b_grid_desc_g_n_k
,
d0_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
...
...
@@ -865,6 +915,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_kernel_args_
.
push_back
({
p_a_grid
,
p_b_grid
,
p_d0_grid
,
p_z_grid
,
p_b1_grid
,
p_c_grid
,
...
...
@@ -875,6 +926,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
...
...
@@ -896,6 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_random_matrix_offset
=
z_random_matrix_offset
+
raw_m_padded
*
raw_n_padded
*
batch_count
;
// for check
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride
;
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_lengths
[
NumDimG
+
NumDimM
]);
d0_n_length_stride
.
push_back
(
tmp_d0_gs_ms_ns_strides
[
NumDimG
+
NumDimM
]);
group_device_args_
.
push_back
(
{{
problem_desc
.
a_gs_ms_ks_lengths
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
b_gs_ns_ks_lengths
[
NumDimG
+
NumDimN
-
1
],
...
...
@@ -910,7 +967,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_g_m_n
,
batch_count
});
batch_count
,
d0_n_length_stride
});
}
// TODO: implement bias addition
// ignore = p_acc0_biases;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
67e10a6a
...
...
@@ -708,7 +708,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
D0GridDescriptor_M0_N0_M1_M2_N1_M3
d0_grid_desc_m0_n0_m1_m2_n1_m3_
;
;
ZGridDesc_M_N
z_grid_desc_m_n_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
...
...
@@ -835,9 +834,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc
.
a_gs_ms_ks_lengths
,
problem_desc
.
a_gs_ms_ks_strides
);
const
auto
b_grid_desc_bk0_n_bk1
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
problem_desc
.
b_gs_ns_ks_lengths
,
problem_desc
.
b_gs_ns_ks_strides
);
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_lengths
;
std
::
vector
<
index_t
>
tmp_d0_gs_ms_ns_strides
;
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
tmp_d0_gs_ms_ns_lengths
=
problem_desc
.
acc0_biases_gs_ms_ns_lengths
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment