Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b41172a3
Commit
b41172a3
authored
Oct 08, 2023
by
Qianfeng Zhang
Browse files
Some tiny updates
parent
b23b3d71
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
14 additions
and
23 deletions
+14
-23
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+1
-14
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+10
-6
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
b41172a3
...
...
@@ -918,7 +918,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
b41172a3
...
...
@@ -448,19 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static
auto
MakeYGradGridDescriptor_O0_M_O1
(
const
std
::
vector
<
index_t
>&
y_gs_ms_os_lengths
,
const
std
::
vector
<
index_t
>&
y_gs_ms_os_strides
)
{
return
Transform
::
MakeAGridDescriptor_AK0_M_AK1
(
Transform
::
MakeAGridDescriptor_M_K
(
y_gs_ms_os_lengths
,
y_gs_ms_os_strides
),
Number
<
Y_O1
>
{});
}
// V in Gemm B position
static
auto
MakeVGridDescriptor_O0_N_O1
(
const
std
::
vector
<
index_t
>&
v_gs_os_ns_lengths
,
const
std
::
vector
<
index_t
>&
v_gs_os_ns_strides
)
...
...
@@ -988,7 +975,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
z_grid_desc_g_m_n
,
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
View file @
b41172a3
...
...
@@ -694,7 +694,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_g
rid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
lse_g
s_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
b41172a3
...
...
@@ -795,7 +795,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
b1_grid_desc_g_n_k
,
c_grid_desc_g_m_n
,
z_grid_desc_g_m_n
,
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()
));
type_convert
<
index_t
>
(
problem_desc
.
lse_gs_ms_strides
[
NumDimG
-
1
]
));
// C0 mask
const
auto
c0_matrix_mask
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
b41172a3
...
...
@@ -1620,7 +1620,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_y_grid
,
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetElementSpaceSize
());
const
auto
lse_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_lse_grid
,
lse_grid_desc_m
.
GetElementSpaceSize
());
const
auto
ygrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
// ygrad dynamic buffer used for calculating y_dot_dy
const
auto
ygrad_grid_buf1
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
y_grid_desc_mblock_mperblock_oblock_operblock
.
GetElementSpaceSize
());
// ygrad dynamic buffer used for calculating dV = Pdrop^T * dY or dPdrop = dY * V^T
const
auto
ygrad_grid_buf2
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ygrad_grid
,
ygrad_grid_desc_m0_o_m1
.
GetElementSpaceSize
());
auto
vgrad_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_vgrad_grid
,
v_grid_desc_o0_n_o1
.
GetElementSpaceSize
());
...
...
@@ -2208,7 +2212,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple
(
I0
,
I0
,
I0
,
I0
),
y_thread_buf
);
yygrad_threadwise_copy
.
Run
(
y_grid_desc_mblock_mperblock_oblock_operblock
,
ygrad_grid_buf
,
ygrad_grid_buf
1
,
y_thread_desc_m0_m1_o0_o1
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
ygrad_thread_buf
);
...
...
@@ -2492,7 +2496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// preload data into LDS
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_buf
);
ygrad_grid_buf
2
);
vgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
ygrad_grid_desc_m0_o_m1
,
Gemm1
::
b_block_slice_copy_step
);
...
...
@@ -2513,7 +2517,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple
(
I0
,
I0
,
I0
),
gemm1_a_thread_buf
);
vgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_m0_o_m1
,
ygrad_grid_buf
);
ygrad_grid_buf
2
);
block_sync_lds
();
...
...
@@ -2553,7 +2557,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
// preload data into LDS
pgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_o0_m_o1
,
ygrad_grid_buf
);
ygrad_grid_buf
2
);
pgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
ygrad_grid_desc_o0_m_o1
,
Gemm0
::
a_block_slice_copy_step
);
...
...
@@ -2570,7 +2574,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_for
<
0
,
num_ok_block_main_loop
-
1
,
1
>
{}([
&
](
auto
i
)
{
pgrad_gemm_tile_ygrad_blockwise_copy
.
RunRead
(
ygrad_grid_desc_o0_m_o1
,
ygrad_grid_buf
);
ygrad_grid_buf
2
);
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment