Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0999070d
"doc/vscode:/vscode.git/clone" did not exist on "2295ccd0c7e5bcba3602830cb9a503cda1bb350b"
Commit
0999070d
authored
Jul 15, 2023
by
danyao12
Browse files
Merge branch 'attn-train-develop-qloop' into mha-train-develop
parents
5ba30232
68e3bb6d
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
182 additions
and
144 deletions
+182
-144
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
...l/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+42
-35
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+16
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+13
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+13
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+46
-44
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+49
-47
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
0999070d
...
...
@@ -1047,7 +1047,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if
0
#if
DEBUG_LOG
arg
.
Print
();
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
0999070d
...
...
@@ -1048,7 +1048,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if
0
#if
DEBUG_LOG
arg
.
Print
();
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
View file @
0999070d
...
...
@@ -1041,7 +1041,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if
0
#if
DEBUG_LOG
arg
.
Print
();
#endif
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
0999070d
...
...
@@ -48,6 +48,7 @@ template <typename GridwiseGemm,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -120,7 +121,7 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -155,7 +156,7 @@ __global__ void
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
...
...
@@ -933,7 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
GridwiseGemm
,
...
...
@@ -957,6 +958,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
...
...
@@ -998,9 +1000,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg
.
m_raw_padded_
,
arg
.
n_raw_padded_
);
};
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_drop_
>
0.0
){
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
...
...
@@ -1020,6 +1024,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if DEBUG_LOG
arg
.
Print
();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
0999070d
...
...
@@ -47,6 +47,7 @@ template <typename GridwiseGemm,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -119,7 +120,7 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -154,7 +155,7 @@ __global__ void
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
...
...
@@ -950,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
...
...
@@ -974,6 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
...
...
@@ -1021,11 +1023,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_drop_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_drop_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
...
...
@@ -1047,6 +1055,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
#if DEBUG_LOG
arg
.
Print
();
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
0999070d
...
...
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -105,7 +106,7 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -141,7 +142,7 @@ __global__ void
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -961,7 +962,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
GridwiseGemm
,
...
...
@@ -972,6 +973,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
...
...
@@ -996,11 +998,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
some_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
0999070d
...
...
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
...
...
@@ -105,7 +106,7 @@ __global__ void
{
for
(
index_t
i
=
0
;
i
<
num_blocks_per_batch
;
i
++
)
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -141,7 +142,7 @@ __global__ void
}
else
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -968,7 +969,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
...
...
@@ -979,6 +980,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
return
launch_and_time_kernel
(
...
...
@@ -1003,11 +1005,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
some_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_dropout_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
0999070d
...
...
@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
...
...
@@ -1947,6 +1948,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
){
if
(
p_z_grid
)
{
...
...
@@ -1996,7 +1998,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
0999070d
...
...
@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
YGradGridDesc_M0_O_M1
>
...
...
@@ -1863,6 +1864,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
if
constexpr
(
IsDropout
){
if
(
p_z_grid
)
{
...
...
@@ -1911,7 +1913,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
// gemm dV
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment