Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
99436cd4
Commit
99436cd4
authored
Jul 20, 2024
by
danyao12
Browse files
save clear_tile
parent
b3100b6f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
8 deletions
+6
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+6
-8
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
99436cd4
...
@@ -537,7 +537,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -537,7 +537,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{
{
// STAGE 1, Q@K Gemm0
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
auto
st_acc
=
SPTBlockTileType
{};
clear_tile
(
st_acc
);
q_block_tile
=
load_tile
(
q_dram_window
);
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
...
@@ -551,7 +550,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -551,7 +550,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
d_block_tile
=
load_tile
(
d_dram_window
);
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
move_tile_window
(
d_dram_window
,
{
kM0
});
gemm_0
(
st_acc
,
q_reg_tensor
,
k_reg_tensor
);
st_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
0
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
0
>();
...
@@ -670,9 +670,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -670,9 +670,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 4, OGrad@V Gemm2
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
auto
dpt_acc
=
SPGradTBlockTileType
{};
clear_tile
(
dpt_acc
);
gemm_2
(
dpt_acc
,
do_reg_tensor
,
v_reg_tensor
);
dpt_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
block_sync_lds
();
block_sync_lds
();
...
@@ -804,10 +803,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -804,10 +803,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// Tail
// Tail
auto
st_acc
=
SPTBlockTileType
{};
auto
st_acc
=
SPTBlockTileType
{};
clear_tile
(
st_acc
);
// STAGE 1, Q@K Gemm0
// STAGE 1, Q@K Gemm0
gemm_0
(
st_acc
,
q_reg_tensor
,
k_reg_tensor
);
st_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
...
@@ -919,10 +917,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -919,10 +917,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 4, OGrad@V Gemm2
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
auto
dpt_acc
=
SPGradTBlockTileType
{};
clear_tile
(
dpt_acc
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
gemm_2
(
dpt_acc
,
do_reg_tensor
,
v_reg_tensor
);
dpt_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment