Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
de53e421
Commit
de53e421
authored
Aug 17, 2023
by
letaoqin
Browse files
load d0 to lds
parent
ec2ad713
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
4 deletions
+25
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
...id/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
+25
-4
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_v2.hpp
View file @
de53e421
...
@@ -1229,7 +1229,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1229,7 +1229,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
__device__
static
void
__device__
static
void
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
Run
(
const
InputDataType
*
__restrict__
p_q_grid
,
const
InputDataType
*
__restrict__
p_k_grid
,
const
InputDataType
*
__restrict__
p_k_grid
,
const
D0DataType
*
__restrict__
p_d_grid
,
const
D0DataType
*
__restrict__
p_d
0
_grid
,
ZDataType
*
__restrict__
p_z_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
...
@@ -1262,8 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1262,8 +1262,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const
index_t
raw_n_padded
,
const
index_t
raw_n_padded
,
const
index_t
block_idx_n
)
const
index_t
block_idx_n
)
{
{
ignore
=
d0_grid_desc_m0_n0_m1_m2_n1_m3
;
ignore
=
p_d_grid
;
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
const
ushort
p_dropout_in_16bits
=
...
@@ -1940,6 +1938,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1940,6 +1938,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
// add bias
// add bias
if
constexpr
(
!
is_same
<
D0DataType
,
void
>::
value
)
{
const
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d0_grid_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
auto
d0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
GemmDataType
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_offset
,
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
.
GetElementSpaceSize
());
static_for
<
0
,
D0M1
,
1
>
{}([
&
](
auto
)
{
d0_block_copy_global_to_lds
.
RunRead
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
d0_grid_buf
);
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
0
,
0
,
1
,
0
,
0
,
0
));
d0_block_copy_global_to_lds
.
RunWrite
(
D0Loader
::
d0_block_desc_m0_n0_m1_m2_n1_m3
,
d0_block_buf
);
});
d0_block_copy_global_to_lds
.
MoveSrcSliceWindow
(
d0_grid_desc_m0_n0_m1_m2_n1_m3
,
make_multi_index
(
1
,
0
,
-
D0M1
.
value
,
0
,
0
,
0
));
}
// P_i: = softmax(scalar * S_i:)
// P_i: = softmax(scalar * S_i:)
// scaling is already performed in the preceding statements with s_element_op
// scaling is already performed in the preceding statements with s_element_op
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment