gaoqiong / composable_kernel / Commits / 711663bc

Commit 711663bc, authored Jan 16, 2023 by qin letao
Parent: e3a2651b

    add P and ds dropout

Showing 3 changed files with 25 additions and 5 deletions (+25 -5)
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp (+3 -2)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp (+1 -0)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp (+21 -3)
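Taken together: BlockwiseDropout gains a sign-bit mode, the device layer forwards the scalar p_dropout, and the gridwise multi-head attention backward kernel uses both to apply dropout to P and to reconstruct dS from the sign-encoded buffer.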
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -16,7 +16,7 @@ struct BlockwiseDropout
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
     static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

-    template <typename CThreadBuffer>
+    template <typename CThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf,
                                           ck::philox ph,
                                           const int repeat_index,
@@ -24,7 +24,8 @@ struct BlockwiseDropout
     {
         auto execute_dropout = [&](bool keep, DataType val) {
-            return keep ? val * p_dropout_rescale : float(0);
+            return keep ? val * p_dropout_rescale
+                        : (using_sign_bit ? -val * p_dropout_rescale : float(0));
         };
         constexpr int tmp_size = MRepeat * KRepeat;
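The new using_sign_bit mode preserves the magnitude of dropped elements instead of zeroing them: a dropped value is stored as its negation, so a later pass can recover both the rescaled value and the keep/drop decision from the sign alone (softmax outputs are non-negative, so the sign bit is free). A minimal host-side sketch of the semantics, assuming p_dropout_rescale is the reciprocal of the keep probability; the helper name dropout_keep is illustrative, not from the source:

    #include <cassert>

    // Mirrors the updated execute_dropout lambda.
    // using_sign_bit == false : dropped elements become 0 (old behaviour).
    // using_sign_bit == true  : dropped elements become -val * rescale,
    //                           so the drop decision survives in the sign bit.
    template <bool using_sign_bit>
    float dropout_keep(bool keep, float val, float rescale)
    {
        return keep ? val * rescale : (using_sign_bit ? -val * rescale : 0.f);
    }

    int main()
    {
        const float rescale = 1.f / 0.8f; // keep probability 0.8
        assert(dropout_keep<true>(true, 0.5f, rescale) > 0.f);   // kept
        assert(dropout_keep<true>(false, 0.5f, rescale) < 0.f);  // dropped
        assert(dropout_keep<false>(false, 0.5f, rescale) == 0.f); // old path
        return 0;
    }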
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp

@@ -130,6 +130,7 @@ __global__ void
             block_2_ctile_map,
             c0_matrix_mask,
             p_dropout_in_16bits,
+            p_dropout,
             rp_dropout,
             ph);
 #else
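This device-level hunk is plumbing only: it forwards the new scalar p_dropout through the kernel launch into the gridwise argument list added below.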
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
@@ -16,6 +16,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"

 namespace ck {
@@ -1121,6 +1122,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const Block2CTileMap& block_2_ctile_map,
         const C0MatrixMask& c0_matrix_mask,
         const ushort p_dropout_in_16bits,
+        FloatGemmAcc p_dropout,
         FloatGemmAcc rp_dropout,
         ck::philox& ph)
     {
@@ -1357,6 +1359,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                                          decltype(thread_cluster_desc_m_n),
                                          decltype(thread_slice_desc_m_n)>{};

+        auto blockwise_dropout =
+            BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
+                p_dropout_in_16bits, rp_dropout};

         auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
             MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
@@ -1600,7 +1605,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
                 const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
                 y_dot_ygrad_block_accum_buf.AtomicAdd(
-                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]);
+                    idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropoutD1
             });
             block_sync_lds();
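The extra p_dropout factor in the AtomicAdd is what makes the sign-bit branch in the final hunk self-consistent. A sketch of the algebra, assuming p_dropout is the keep probability $p$ and the buffer after ApplyDropout holds $z_{ij} = \pm P_{ij}/p$ (the sign encoding the dropout mask): the softmax backward is

    $dS_{ij} = P_{ij}\,(dP_{ij} - D_i)$, with $D_i = \sum_o Y_{io}\, dY_{io}$,

and substituting $dP_{ij} = dZ_{ij}/p$ (kept) or $0$ (dropped) gives

    kept, $z_{ij} \ge 0$:   $dS_{ij} = z_{ij}\,(dZ_{ij} - p\,D_i)$
    dropped, $z_{ij} < 0$:  $dS_{ij} = z_{ij}\,p\,D_i$

so accumulating $p\,D_i$ once here lets both branches of the epilogue consume y_dot_ygrad_thread_buf directly.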
@@ -1717,6 +1722,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // scaling is already performed in the preceding statements with s_element_op
             blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);

+            // P_dropped
+            blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
+                s_slash_p_thread_buf, ph, gemm1_k_block_outer_index, num_gemm1_k_block_outer_loop);
+
             block_sync_lds(); // wait for gemm1 LDS read
             SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
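The second template argument true selects the new using_sign_bit path of ApplyDropout, so after this call s_slash_p_thread_buf holds the rescaled P with dropped entries negated rather than zeroed.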
@@ -1807,8 +1816,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 constexpr auto m =
                     pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
                 // dS and P has same thread buf layout
-                sgrad_thread_buf(i) = s_slash_p_thread_buf[i] *
-                                      (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                if(s_slash_p_thread_buf[i] >= 0)
+                {
+                    sgrad_thread_buf(i) =
+                        s_slash_p_thread_buf[i] *
+                        (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
+                }
+                else
+                {
+                    sgrad_thread_buf(i) =
+                        s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
+                }
             });
             // gemm dQ
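The branch on the sign of s_slash_p_thread_buf[i] is the consumer of the sign-bit encoding: non-negative means the element was kept, negative means it was dropped. A self-contained sketch, under the same assumptions as above (p is the keep probability, all names illustrative), checking that the two branches agree with the reference dS = P * (dP - D):

    #include <cassert>
    #include <cmath>
    #include <initializer_list>

    int main()
    {
        const float p  = 0.8f;  // keep probability (p_dropout)
        const float P  = 0.3f;  // softmax output, always >= 0
        const float dZ = 0.7f;  // upstream gradient dY*V^T (pgrad)
        const float D  = 0.2f;  // row sum  sum_o Y_io * dY_io
        const float Dp = D * p; // what the AtomicAdd hunk accumulates

        for(bool kept : {true, false})
        {
            // forward, sign-bit mode: kept -> P/p, dropped -> -P/p
            const float z = kept ? P / p : -P / p;
            // reference: dS = P * (dP - D), with dP = mask * dZ / p
            const float dS_ref = P * ((kept ? dZ / p : 0.f) - D);
            // the commit's epilogue branches on the sign of z
            const float dS = (z >= 0.f) ? z * (dZ - Dp) : z * Dp;
            assert(std::fabs(dS - dS_ref) < 1e-6f);
        }
        return 0;
    }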