Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e1287b9a
"experiments/vscode:/vscode.git/clone" did not exist on "042fadfd7df22bd7734c68898c1c34598c8b2d77"
Commit
e1287b9a
authored
Jul 14, 2023
by
fsx950223
Browse files
skip dropout when dropout=0
parent
f5c70413
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
147 additions
and
131 deletions
+147
-131
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+39
-35
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+13
-5
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
+46
-44
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+49
-47
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
e1287b9a
...
@@ -48,6 +48,7 @@ template <typename GridwiseGemm,
...
@@ -48,6 +48,7 @@ template <typename GridwiseGemm,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
...
@@ -119,7 +120,7 @@ __global__ void
...
@@ -119,7 +120,7 @@ __global__ void
{
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -154,7 +155,7 @@ __global__ void
...
@@ -154,7 +155,7 @@ __global__ void
}
}
else
else
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
...
@@ -932,7 +933,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -932,7 +933,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1
<
GridwiseGemm
,
GridwiseGemm
,
...
@@ -956,6 +957,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -956,6 +957,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
C0MatrixMask
,
has_main_k_block_loop_
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
Deterministic
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
...
@@ -997,9 +999,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -997,9 +999,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg
.
m_raw_padded_
,
arg
.
m_raw_padded_
,
arg
.
n_raw_padded_
);
arg
.
n_raw_padded_
);
};
};
if
(
arg
.
p_drop_
>
0.0
){
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
return
ave_time
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
e1287b9a
...
@@ -47,6 +47,7 @@ template <typename GridwiseGemm,
...
@@ -47,6 +47,7 @@ template <typename GridwiseGemm,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
Deterministic
>
bool
Deterministic
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
...
@@ -118,7 +119,7 @@ __global__ void
...
@@ -118,7 +119,7 @@ __global__ void
{
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
...
@@ -153,7 +154,7 @@ __global__ void
...
@@ -153,7 +154,7 @@ __global__ void
}
}
else
else
{
{
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
IsDropout
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
...
@@ -949,7 +950,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -949,7 +950,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
)
{
const
auto
kernel
=
const
auto
kernel
=
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2
<
GridwiseGemm
,
GridwiseGemm
,
...
@@ -973,6 +974,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -973,6 +974,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
ComputeBasePtrOfStridedBatch
,
ComputeBasePtrOfStridedBatch
,
C0MatrixMask
,
C0MatrixMask
,
has_main_k_block_loop_
,
has_main_k_block_loop_
,
is_dropout_
,
Deterministic
>
;
Deterministic
>
;
return
launch_and_time_kernel
(
return
launch_and_time_kernel
(
...
@@ -1020,11 +1022,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1020,11 +1022,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
if
(
arg
.
p_drop_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
else
else
{
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
if
(
arg
.
p_drop_
>
0.0
)
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
else
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
return
ave_time
;
return
ave_time
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v1.hpp
View file @
e1287b9a
...
@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1222,6 +1222,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_O0_M_O1
>
typename
YGradGridDesc_O0_M_O1
>
...
@@ -1947,6 +1948,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1947,6 +1948,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
constexpr
auto
position_offset
=
M3
*
M4
;
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
// save z to global
if
constexpr
(
IsDropout
){
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
...
@@ -1996,7 +1998,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1996,7 +1998,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
true
>(
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
// dS = P * (dP - Y_dot_dY)
// dS = P * (dP - Y_dot_dY)
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
e1287b9a
...
@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1154,6 +1154,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
C0MatrixMask
,
typename
C0MatrixMask
,
typename
YGradGridDesc_M0_O_M1
>
typename
YGradGridDesc_M0_O_M1
>
...
@@ -1863,6 +1864,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1863,6 +1864,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
constexpr
auto
position_offset
=
M3
*
M4
;
constexpr
auto
position_offset
=
M3
*
M4
;
// save z to global
// save z to global
if
constexpr
(
IsDropout
){
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
...
@@ -1911,7 +1913,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1911,7 +1913,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true
>(
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
// gemm dV
// gemm dV
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment