Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b472e28
"...composable_kernel.git" did not exist on "c8f3acf9c015fbbba11456df5e829e0e7f57eaf2"
Commit
0b472e28
authored
Jul 07, 2023
by
ltqin
Browse files
group remove y_grid_desc_mblock_mperblock_oblock_operblock parameter
parent
15713b20
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
62 additions
and
96 deletions
+62
-96
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v3.cpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+30
-31
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+30
-31
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+0
-16
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+0
-16
No files found.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v3.cpp
View file @
0b472e28
...
...
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
32
// DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
...
...
@@ -260,7 +260,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
int
run
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
2
;
// method 1 will have slightly higher error; TODO: to investigate
int
init_method
=
1
;
// method 1 will have slightly higher error; TODO: to investigate
bool
time_kernel
=
true
;
// Overall QKV matrices shape
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
0b472e28
...
...
@@ -174,37 +174,36 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
i
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
i
);
}
}
else
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
0b472e28
...
...
@@ -173,37 +173,36 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
i
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
b1_grid_desc_bk0_n_bk1
,
lse_grid_desc_m
,
d_grid_desc_m
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
,
z_random_matrix_offset
,
raw_n_padded
,
i
);
}
}
else
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
0b472e28
...
...
@@ -179,7 +179,6 @@ __global__ void
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
...
...
@@ -216,7 +215,6 @@ __global__ void
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_o0_m_o1_
,
...
...
@@ -733,8 +731,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -885,21 +881,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
))
{
y_grid_desc_mblock_mperblock_oblock_operblock
=
GridwiseGemm
::
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
y_grid_desc_m_o
);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
...
...
@@ -973,7 +958,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V1
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
lse_grid_desc_m
,
k_grid_desc_n_k
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
0b472e28
...
...
@@ -179,7 +179,6 @@ __global__ void
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
...
...
@@ -216,7 +215,6 @@ __global__ void
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
arg_ptr
[
group_id
].
ygrad_grid_desc_m0_o_m1_
,
...
...
@@ -740,8 +738,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
YGridDesc_M_O
y_grid_desc_m_o_
;
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_
;
LSEGridDesc_M
lse_grid_desc_m_
;
...
...
@@ -891,21 +887,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
problem_desc
.
b1_gs_gemm1ns_gemm1ks_strides
);
const
auto
c_grid_desc_g_m_n
=
Transform
::
MakeCGridDescriptor_G_M_N
(
problem_desc
.
c_gs_ms_gemm1ns_lengths
,
problem_desc
.
c_gs_ms_gemm1ns_strides
);
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
))
{
y_grid_desc_mblock_mperblock_oblock_operblock
=
GridwiseGemm
::
MakeYGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
(
y_grid_desc_m_o
);
}
const
index_t
batch_count
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
index_t
grid_size_grp
=
...
...
@@ -975,7 +960,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_Light_V2
z_grid_desc_m_n
,
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
lse_grid_desc_m
,
k_grid_desc_n_k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment