Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
51e102e5
Commit
51e102e5
authored
Jun 19, 2023
by
guangzlu
Browse files
moidfied arg names in bwd qloop
parent
7d6a8ec7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+6
-6
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
51e102e5
...
...
@@ -123,7 +123,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
...
...
@@ -159,7 +159,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
...
...
@@ -661,7 +661,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
LSEGridDesc_M
lse_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
YGradGridDesc_O0_M_O1
ygrad_grid_desc_o0_m_o1_
;
...
...
@@ -800,7 +800,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
...
...
@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
y_grid_desc_m_o
);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n
);
...
...
@@ -869,7 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
lse_grid_desc_m
,
k_grid_desc_n_k
,
ygrad_grid_desc_o0_m_o1
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
51e102e5
...
...
@@ -123,7 +123,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
...
...
@@ -159,7 +159,7 @@ __global__ void
c_element_op
,
arg_ptr
[
group_id
].
a_grid_desc_ak0_m_ak1_
,
arg_ptr
[
group_id
].
b_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
,
arg_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
,
arg_ptr
[
group_id
].
b1_grid_desc_bk0_n_bk1_
,
arg_ptr
[
group_id
].
y_grid_desc_mblock_mperblock_oblock_operblock_
,
arg_ptr
[
group_id
].
lse_grid_desc_m_
,
...
...
@@ -669,7 +669,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
_
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
_
;
LSEGridDesc_M
lse_grid_desc_m_
;
KGridDesc_N_K
k_grid_desc_n_k_
;
YGradGridDesc_M0_O_M1
ygrad_grid_desc_m0_o_m1_
;
...
...
@@ -808,7 +808,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock
;
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
;
const
index_t
BlockStart
=
grid_size_
;
const
auto
block_2_ctile_map
=
Block2CTileMap
(
k_grid_desc_n_k
,
BlockStart
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
...
...
@@ -821,7 +821,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_m_o
);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
=
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
(
z_grid_desc_m_n
);
...
...
@@ -877,7 +877,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
b1_grid_desc_bk0_n_bk1
,
y_grid_desc_m_o
,
y_grid_desc_mblock_mperblock_oblock_operblock
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
n3_n4_n5
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_
m4_m5_n3
,
lse_grid_desc_m
,
k_grid_desc_n_k
,
ygrad_grid_desc_m0_o_m1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment