Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7e402f6a
Commit
7e402f6a
authored
Mar 16, 2023
by
guangzlu
Browse files
added switch for dropout in bwd pass
parent
665b08cf
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
243 additions
and
147 deletions
+243
-147
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+49
-34
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
+58
-35
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
+30
-8
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
..._grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
+30
-8
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
+38
-31
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
+38
-31
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
7e402f6a
...
...
@@ -47,7 +47,8 @@ template <typename GridwiseGemm,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
IsDropout
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -111,34 +112,35 @@ __global__ void
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
vgrad_grid_desc_n_o
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
vgrad_grid_desc_n_o
,
ygrad_grid_desc_o0_m_o1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -786,8 +788,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
y_grid_desc_m_o_
);
}
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
is_dropout_
=
p_drop_
>
0.0
;
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
...
...
@@ -877,6 +880,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
bool
is_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
};
...
...
@@ -898,7 +902,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1
<
GridwiseGemm
,
DataType
,
...
...
@@ -920,7 +924,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
is_dropout_
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
...
...
@@ -970,7 +975,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
// ave_time = launch_kernel(integral_constant<bool, false>{});
#endif
return
ave_time
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
7e402f6a
...
...
@@ -46,7 +46,8 @@ template <typename GridwiseGemm,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
IsDropout
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -110,34 +111,35 @@ __global__ void
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
vgrad_grid_desc_n_o
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
p_qgrad_grid
+
a_batch_offset
,
p_kgrad_grid
+
b_batch_offset
,
p_vgrad_grid
+
b1_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
lse_grid_desc_m
,
vgrad_grid_desc_n_o
,
ygrad_grid_desc_m0_o_m1
,
block_2_ctile_map
,
c0_matrix_mask
,
p_drop
,
ph
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -784,8 +786,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_m_o_
);
}
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
is_dropout_
=
p_drop_
>
0.0
;
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
...
...
@@ -875,6 +878,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_drop_
;
bool
is_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
};
...
...
@@ -900,7 +904,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2
<
GridwiseGemm
,
DataType
,
...
...
@@ -922,7 +926,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
is_dropout_
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
...
...
@@ -966,11 +971,29 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
#if 1
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
#endif
return
ave_time
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
7e402f6a
...
...
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
IsDropout
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -99,7 +100,7 @@ __global__ void
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -685,8 +686,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
}
{
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
is_dropout_
=
p_drop
>
0.0
;
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
problem_desc_vec
.
size
());
...
...
@@ -867,6 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation
c_element_op_
;
float
p_dropout_
;
bool
is_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
...
...
@@ -908,7 +911,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1
<
GridwiseGemm
,
GroupKernelArg
,
...
...
@@ -917,7 +920,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
is_dropout_
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -941,11 +945,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
if
(
!
some_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp
View file @
7e402f6a
...
...
@@ -34,7 +34,8 @@ template <typename GridwiseGemm,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
,
bool
IsDropout
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
/*CK_MIN_BLOCK_PER_CU*/
1
)
...
...
@@ -99,7 +100,7 @@ __global__ void
(
arg_ptr
[
group_id
].
p_z_grid_
==
nullptr
?
nullptr
:
arg_ptr
[
group_id
].
p_z_grid_
+
z_batch_offset
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
arg_ptr
[
group_id
].
p_a_grid_
+
a_batch_offset
,
arg_ptr
[
group_id
].
p_b_grid_
+
b_batch_offset
,
z_matrix_ptr
,
...
...
@@ -678,8 +679,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
c_element_op_
{
c_element_op
},
p_dropout_
{
p_drop
}
{
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
is_dropout_
=
p_drop
>
0.0
;
seed_
=
std
::
get
<
0
>
(
seeds
);
offset_
=
std
::
get
<
1
>
(
seeds
);
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
problem_desc_vec
.
size
());
...
...
@@ -860,6 +862,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation
c_element_op_
;
float
p_dropout_
;
bool
is_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
...
...
@@ -900,7 +903,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2
<
GridwiseGemm
,
GroupKernelArg
,
...
...
@@ -909,7 +912,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
has_main_k_block_loop_
>
;
has_main_k_block_loop_
,
is_dropout_
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
@@ -933,11 +937,29 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// to concern Gemm0's loop
if
(
all_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
if
(
!
some_has_main_k_block_loop
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
is_dropout_
)
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp
View file @
7e402f6a
...
...
@@ -1230,6 +1230,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
VGradGridDescriptor_N_O
,
...
...
@@ -1957,38 +1958,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
if
(
p_z_grid
)
if
constexpr
(
IsDropout
)
{
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
// save z to global
if
(
p_z_grid
)
{
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
)
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
ignore
=
z_grid_buf
;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
ignore
=
z_grid_buf
;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
...
...
@@ -2176,9 +2183,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Gemm2
::
b_block_reset_copy_step
);
// rewind M
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
//
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
//
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
//
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
View file @
7e402f6a
...
...
@@ -1140,6 +1140,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
VGradGridDescriptor_N_O
,
...
...
@@ -1852,38 +1853,44 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// scaling is already performed in the preceding statements with s_element_op
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
if
(
p_z_grid
)
if
constexpr
(
IsDropout
)
{
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
// save z to global
if
(
p_z_grid
)
{
// P_dropped
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
true
,
decltype
(
n0
),
decltype
(
i
)>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
)
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
n0
.
value
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
ignore
=
z_grid_buf
;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
ignore
=
z_grid_buf
;
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
...
...
@@ -2126,9 +2133,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
//
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
//
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
//
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment