Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3d5b0755
Commit
3d5b0755
authored
Aug 02, 2024
by
danyao12
Browse files
non-iglp pipeline for headdim padding cases
parent
f8b14618
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1107 additions
and
315 deletions
+1107
-315
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+12
-9
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+13
-7
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+42
-298
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+1037
-0
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
...k_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
+2
-1
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
3d5b0755
...
@@ -14,11 +14,13 @@ from codegen.cpp_symbol_map import *
...
@@ -14,11 +14,13 @@ from codegen.cpp_symbol_map import *
BWD_DQDKDV_PIPELINE_MAP
=
{
BWD_DQDKDV_PIPELINE_MAP
=
{
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR"
,
"kr_ktr_vr_iglp"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP"
,
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR"
,
}
}
BWD_DQDKDV_PIPELINE_ENUM_MAP
=
{
BWD_DQDKDV_PIPELINE_ENUM_MAP
=
{
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR"
,
"kr_ktr_vr_iglp"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP"
,
"kr_ktr_vr"
:
"ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR"
,
}
}
FMHA_BWD_KERNEL_HEADER
=
"""// SPDX-License-Identifier: MIT
FMHA_BWD_KERNEL_HEADER
=
"""// SPDX-License-Identifier: MIT
...
@@ -408,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
...
@@ -408,7 +410,7 @@ class FmhaBwdDQDKDVKernel:
if
n
!=
''
:
n
=
'p'
+
n
if
n
!=
''
:
n
=
'p'
+
n
return
n
return
n
pn
=
pad_name
()
pn
=
pad_name
()
n
=
f
"fmha_bwd_d
{
self
.
F_hdim
}
_
{
self
.
F_dtype
}
_
{
self
.
F_mode
}
_"
+
self
.
F_tile
.
name
n
=
f
"fmha_bwd_d
{
self
.
F_hdim
}
_
{
self
.
F_dtype
}
_
{
self
.
F_mode
}
_"
+
self
.
F_tile
.
name
+
f
'_
{
self
.
F_pipeline
}
'
if
pn
!=
''
:
n
+=
f
'_
{
pn
}
'
if
pn
!=
''
:
n
+=
f
'_
{
pn
}
'
if
self
.
F_bias
!=
'no'
:
n
+=
f
'_
{
self
.
F_bias
}
'
if
self
.
F_bias
!=
'no'
:
n
+=
f
'_
{
self
.
F_bias
}
'
if
self
.
F_dbias
==
't'
:
n
+=
'_dbias'
if
self
.
F_dbias
==
't'
:
n
+=
'_dbias'
...
@@ -450,13 +452,13 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
...
@@ -450,13 +452,13 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
return
{
'32'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
32
,
32
,
32
,
32
,
64
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'32'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
32
,
32
,
32
,
32
,
64
,
32
,
32
,
1
,
4
,
1
,
4
,
1
,
1
,
2
,
2
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
'64'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
64
,
32
,
64
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'64'
:
[
FmhaBwdDQDKDVTileSize
(
32
,
128
,
64
,
32
,
64
,
32
,
32
,
64
,
64
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'128'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
128
,
128
,
16
,
128
,
16
,
32
,
128
,
128
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
],
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
],
'256'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
64
,
256
,
16
,
256
,
16
,
32
,
256
,
256
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
'256'
:
[
FmhaBwdDQDKDVTileSize
(
16
,
64
,
256
,
16
,
256
,
16
,
32
,
256
,
256
,
1
,
4
,
1
,
4
,
1
,
1
,
1
,
4
,
1
,
16
,
16
,
32
,
16
,
16
,
16
,
1
),
"kr_ktr_vr"
]
"kr_ktr_vr_iglp"
,
"kr_ktr_vr"
]
}
}
else
:
else
:
return
None
return
None
...
@@ -481,6 +483,8 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -481,6 +483,8 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
continue
continue
if
(
"wg32"
in
dropout
):
if
(
"wg32"
in
dropout
):
continue
continue
if
(
dpad
==
"t"
or
dvpad
==
"t"
):
ppl
=
d
[
hdim_str
][
2
]
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
k
=
FmhaBwdDQDKDVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
F_spad
=
spad
,
F_skpad
=
skpad
,
F_dpad
=
dpad
,
F_dvpad
=
dvpad
,
F_bias
=
bias
,
F_dbias
=
dbias
,
F_dropout
=
dropout
,
F_mask
=
mask
,
F_mode
=
mode
,
F_bias
=
bias
,
F_dbias
=
dbias
,
F_dropout
=
dropout
,
F_mask
=
mask
,
F_mode
=
mode
,
...
@@ -497,8 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -497,8 +501,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if
receipt
==
3
:
if
receipt
==
3
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
bias
in
[
'no'
,
'alibi'
]
cond
&=
dpad
==
"f"
cond
&=
dpad
==
dvpad
cond
&=
dvpad
==
"f"
cond
&=
deterministic
==
"f"
cond
&=
deterministic
==
"f"
if
not
cond
:
if
not
cond
:
continue
continue
...
...
include/ck_tile/ops/fmha.hpp
View file @
3d5b0755
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
3d5b0755
...
@@ -72,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
...
@@ -72,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
{
{
// sync with generate.py
// sync with generate.py
// clang-format off
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gbr0
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
gbr1
=
typename
bfs
::
Gemm1BlockWarps
;
using
gbr4
=
typename
bfs
::
Gemm4BlockWarps
;
using
gwt0
=
typename
bfs
::
Gemm0WarpTile
;
using
gwt1
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
auto
pn
=
[
&
]
()
{
...
@@ -87,10 +90,13 @@ struct FmhaBwdDQDKDVKernel
...
@@ -87,10 +90,13 @@ struct FmhaBwdDQDKDVKernel
return
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK2
)
+
"x"
+
_TS_
(
bfs
::
kK3
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
_TS_
(
bfs
::
kK4
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
3d5b0755
...
@@ -488,73 +488,37 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -488,73 +488,37 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
/*
* Prefetch Q, LSE, dO, D
*/
auto
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
auto
do_block_tile
=
load_tile
(
do_dram_window
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
/*
* Store prefetched data into LDS
*/
store_tile
(
q_lds_window
,
q_block_tile
);
shuffle_tile
(
qt_block_tile
,
q_block_tile
);
store_tile
(
qt_lds_write_window
,
qt_block_tile
);
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
store_tile
(
do_lds_window
,
do_block_tile
);
shuffle_tile
(
dot_block_tile
,
do_block_tile
);
store_tile
(
dot_lds_write_window
,
dot_block_tile
);
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
/*
* Prefetch LDS data into Reg to Asynchronous Data Movement and MFMA pipeline
*/
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
clear_tile
(
dv_acc
);
clear_tile
(
dv_acc
);
clear_tile
(
dk_acc
);
clear_tile
(
dk_acc
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// Hot loop
// Hot loop
while
(
i_total_loops
<
(
num_total_loop
-
1
)
)
while
(
i_total_loops
<
num_total_loop
)
{
{
// STAGE 1, Q@K Gemm0
auto
q_block_tile
=
load_tile
(
q_dram_window
);
auto
st_acc
=
SPTBlockTileType
{};
q_block_tile
=
load_tile
(
q_dram_window
);
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
move_tile_window
(
q_dram_window
,
{
kM0
,
0
});
lse_block_tile
=
load_tile
(
lse_dram_window
);
auto
lse_block_tile
=
load_tile
(
lse_dram_window
);
move_tile_window
(
lse_dram_window
,
{
kM0
});
move_tile_window
(
lse_dram_window
,
{
kM0
});
do_block_tile
=
load_tile
(
do_dram_window
);
store_tile
(
q_lds_window
,
q_block_tile
);
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
shuffle_tile
(
qt_block_tile
,
q_block_tile
);
store_tile
(
qt_lds_write_window
,
qt_block_tile
);
d_block_tile
=
load_tile
(
d_dram_window
);
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
move_tile_window
(
d_dram_window
,
{
kM0
});
st_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
block_sync_lds
(
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
auto
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
auto
lse
=
load_tile
(
lse_lds_read_window
);
block_sync_lds
();
// STAGE 1, Q@K Gemm0
auto
st_acc
=
SPTBlockTileType
{};
st_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
0
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
...
@@ -660,36 +624,38 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -660,36 +624,38 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}();
}();
// STAGE 3, P^T@OGrad^T Gemm1
// STAGE 3, P^T@OGrad^T Gemm1
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
auto
do_block_tile
=
load_tile
(
do_dram_window
);
decltype
(
pt_reg_tensor
),
move_tile_window
(
do_dram_window
,
{
kM0
,
0
});
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
auto
d_block_tile
=
load_tile
(
d_dram_window
);
move_tile_window
(
d_dram_window
,
{
kM0
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
store_tile
(
do_lds_window
,
do_block_tile
);
__builtin_amdgcn_sched_barrier
(
0
);
shuffle_tile
(
dot_block_tile
,
do_block_tile
);
// STAGE 4, OGrad@V Gemm2
store_tile
(
dot_lds_write_window
,
dot_block_tile
);
auto
dpt_acc
=
SPGradTBlockTileType
{};
dpt_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
store_tile
(
d_lds_write_window
,
d_block_tile
);
block_sync_lds
();
block_sync_lds
();
store_tile
(
q_lds_window
,
q_block_tile
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
shuffle_tile
(
qt_block_tile
,
q_block_tile
);
store_tile
(
qt_lds_write_window
,
qt_block_tile
);
store_tile
(
lse_lds_write_window
,
lse_block_tile
);
block_sync_lds
(
);
store_tile
(
do_lds_window
,
do_block_tile
);
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
shuffle_tile
(
dot_block_tile
,
do_block_tile
);
decltype
(
pt_reg_tensor
),
store_tile
(
dot_lds_write_window
,
dot_block_tile
);
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
store_tile
(
d_lds_write_window
,
d_block_tile
);
// STAGE 4, OGrad@V Gemm2
auto
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
auto
d
=
load_tile
(
d_lds_read_window
);
block_sync_lds
();
auto
dpt_acc
=
SPGradTBlockTileType
{};
dpt_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE 5, P^T(PGrad^T - D)
// STAGE 5, P^T(PGrad^T - D)
auto
dst
=
SPGradTBlockTileType
{};
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
...
@@ -732,6 +698,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -732,6 +698,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
}
// STAGE 6, SGrad^T@Q^T Gemm3
// STAGE 6, SGrad^T@Q^T Gemm3
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
block_sync_lds
();
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
...
@@ -747,11 +716,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -747,11 +716,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
q_reg_tensor
=
load_tile
(
q_lds_read_window
);
lse
=
load_tile
(
lse_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
__builtin_amdgcn_sched_barrier
(
0
);
// STAGE7 SGrad@K^T Gemm4
// STAGE7 SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
clear_tile
(
dq_acc
);
...
@@ -773,12 +738,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -773,12 +738,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
}
});
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
move_tile_window
(
ds_lds_read_window
,
{
0
,
-
kN0
});
do_reg_tensor
=
load_tile
(
do_lds_read_window
);
d
=
load_tile
(
d_lds_read_window
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
// QGrad Scale
// QGrad Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
{
...
@@ -802,234 +761,19 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
...
@@ -802,234 +761,19 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
i_total_loops
+=
1
;
i_total_loops
+=
1
;
seqlen_q_step
+=
kM0
;
seqlen_q_step
+=
kM0
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Tail
auto
st_acc
=
SPTBlockTileType
{};
// STAGE 1, Q@K Gemm0
st_acc
=
gemm_0
(
q_reg_tensor
,
k_reg_tensor
);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
auto
bias_shuffle_tmp
=
make_static_distributed_tensor
<
BiasDataType
>
(
Policy
::
template
MakeShuffledBiasTileDistribution
<
Problem
>());
shuffle_tile
(
bias_shuffle_tmp
,
bias_tile
);
store_tile
(
biast_lds_shuffle_window
,
bias_shuffle_tmp
);
block_sync_lds
();
auto
biast_tile
=
load_tile
(
biast_lds_window
);
tile_elementwise_inout
(
[
&
](
auto
&
x
,
const
auto
&
y
)
{
x
=
scale
*
x
+
log2e_v
<
AccDataType
>
*
type_convert
<
AccDataType
>
(
y
);
},
st_acc
,
biast_tile
);
}
else
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
constexpr
auto
st_spans
=
decltype
(
st_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
st_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
st_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
st_acc
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
st_acc
(
i_j_idx
)
*=
scale
;
position_encoding
.
update
(
st_acc
(
i_j_idx
),
row
,
col
);
});
});
}
if
constexpr
(
kPadSeqLenK
||
FmhaMask
::
IsMasking
)
{
bool
need_perpixel_check
=
mask
.
IsEdgeTile
(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{});
if
(
need_perpixel_check
)
{
set_tile_if
(
st_acc
,
-
numeric
<
AccDataType
>::
infinity
(),
[
&
](
auto
tile_idx
)
{
const
auto
row
=
seqlen_q_step
+
tile_idx
.
at
(
number
<
0
>
{});
const
auto
col
=
k_origin
.
at
(
number
<
0
>
{})
+
tile_idx
.
at
(
number
<
1
>
{});
return
mask
.
IsOutOfBound
(
row
,
col
);
});
}
}
static
const
auto
get_validated_lse
=
[](
LSEDataType
raw_lse
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
raw_lse
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_lse
;
}
else
{
return
raw_lse
;
}
};
auto
pt
=
SPTBlockTileType
{};
constexpr
auto
pt_spans
=
decltype
(
pt
)
::
get_distributed_spans
();
sweep_tile_span
(
pt_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
auto
row_lse
=
log2e_v
<
LSEDataType
>
*
get_validated_lse
(
lse
[
i_idx
]);
sweep_tile_span
(
pt_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
BiasEnum
==
BlockAttentionBiasEnum
::
ALIBI
)
{
pt
(
i_j_idx
)
=
exp2
(
st_acc
[
i_j_idx
]
-
row_lse
);
}
else
{
pt
(
i_j_idx
)
=
exp2
(
scale
*
st_acc
[
i_j_idx
]
-
row_lse
);
}
});
});
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
dropout
.
template
Run
<
decltype
(
gemm_0
),
RandValOutputDataType
>(
seqlen_q_step
,
k_origin
.
at
(
number
<
0
>
{}),
pt
,
randval_dram_window
);
}
// STAGE 3, P^T@OGrad^T Gemm1
const
auto
pt_gemm
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[](
const
auto
&
x
)
{
return
type_convert
<
GemmDataType
>
(
x
>
0.
f
?
x
:
0.
f
);
},
pt
);
}
else
{
return
cast_tile
<
GemmDataType
>
(
pt
);
}
}();
Policy
::
template
PTFromGemm0CToGemm1A
<
Problem
,
decltype
(
pt_reg_tensor
),
decltype
(
pt_gemm
)>(
pt_reg_tensor
,
pt_gemm
);
auto
dot_reg_tensor
=
load_tile
(
dot_lds_read_window
);
gemm_1
(
dv_acc
,
pt_reg_tensor
,
dot_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
1
>();
// STAGE 4, OGrad@V Gemm2
auto
dpt_acc
=
SPGradTBlockTileType
{};
auto
qt_reg_tensor
=
load_tile
(
qt_lds_read_window
);
dpt_acc
=
gemm_2
(
do_reg_tensor
,
v_reg_tensor
);
HotLoopScheduler
::
template
GemmStagedScheduler
<
2
>();
// STAGE 5, P^T(PGrad^T - D)
auto
dst
=
SPGradTBlockTileType
{};
constexpr
auto
dst_spans
=
decltype
(
dst
)
::
get_distributed_spans
();
sweep_tile_span
(
dst_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
dst_spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
bool
undrop_flag
=
pt
[
i_j_idx
]
>=
0
;
dst
(
i_j_idx
)
=
pt
[
i_j_idx
]
*
(
!
FmhaDropout
::
IsDropout
||
undrop_flag
?
(
dpt_acc
[
i_j_idx
]
-
d
[
i_idx
])
:
d
[
i_idx
]);
});
});
if
constexpr
(
kHasBiasGrad
)
{
const
auto
dbiast
=
[
&
]()
{
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
return
tile_elementwise_in
(
[
&
rp_undrop
](
const
auto
&
x
)
{
return
type_convert
<
BiasGradDataType
>
(
x
*
rp_undrop
);
},
dst
);
}
else
{
return
cast_tile
<
BiasGradDataType
>
(
dst
);
}
}();
store_tile
(
biast_lds_shuffle_window
,
dbiast
);
block_sync_lds
();
auto
dbiast_tile
=
load_tile
(
dbiast_lds_shuffle_window
);
auto
dbiast_shuffle_tmp
=
make_static_distributed_tensor
<
BiasGradDataType
>
(
Policy
::
template
MakeBiasTileDistribution
<
Problem
>());
shuffle_tile
(
dbiast_shuffle_tmp
,
dbiast_tile
);
store_tile
(
dbias_dram_window
,
dbiast_shuffle_tmp
);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const
auto
dst_gemm
=
cast_tile
<
GemmDataType
>
(
dst
);
Policy
::
template
SGradTFromGemm2CToGemm3A
<
Problem
,
decltype
(
dst_reg_tensor
),
decltype
(
dst_gemm
)>(
dst_reg_tensor
,
dst_gemm
);
gemm_3
(
dk_acc
,
dst_reg_tensor
,
qt_reg_tensor
);
store_tile
(
ds_lds_window
,
dst_gemm
);
block_sync_lds
();
auto
ds_reg_tensor
=
load_tile
(
ds_lds_read_window
);
auto
ds_reg_tensor_next
=
decltype
(
ds_reg_tensor
){};
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
3
>();
// STAGE 7, SGrad@K^T Gemm4
auto
dq_acc
=
QGradBlockTileType
{};
clear_tile
(
dq_acc
);
static_for
<
0
,
k4_loops
,
1
>
{}([
&
](
auto
i_k4
)
{
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor_next
=
load_tile
(
ds_lds_read_window
);
move_tile_window
(
ds_lds_read_window
,
{
0
,
kK4
});
}
auto
kt_reg_tensor_slice
=
get_slice_tile
(
kt_reg_tensor
,
sequence
<
0
,
i_k4
*
kK4
>
{},
sequence
<
kQKHeaddim
,
(
i_k4
+
1
)
*
kK4
>
{});
gemm_4
(
dq_acc
,
ds_reg_tensor
,
kt_reg_tensor_slice
);
if
constexpr
(
i_k4
<
k4_loops
-
1
)
{
ds_reg_tensor
.
get_thread_buffer
()
=
ds_reg_tensor_next
.
get_thread_buffer
();
}
});
HotLoopScheduler
::
template
GemmStagedScheduler
<
4
>();
// Results Scale
// Results Scale
if
constexpr
(
FmhaDropout
::
IsDropout
)
if
constexpr
(
FmhaDropout
::
IsDropout
)
{
{
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dq_acc
);
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
tile_elementwise_inout
([
&
scale_rp_undrop
](
auto
&
x
)
{
x
=
x
*
scale_rp_undrop
;
},
dk_acc
);
dk_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
tile_elementwise_inout
([
&
rp_undrop
](
auto
&
x
)
{
x
=
x
*
rp_undrop
;
},
dv_acc
);
}
}
else
else
{
{
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dq_acc
);
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
tile_elementwise_inout
([
&
raw_scale
](
auto
&
x
)
{
x
=
x
*
raw_scale
;
},
dk_acc
);
}
}
if
constexpr
(
kIsDeterministic
)
{
store_tile
(
dq_dram_window
,
dq_acc
);
}
else
{
update_tile
(
dq_dram_window
,
dq_acc
);
}
return
make_tuple
(
dk_acc
,
dv_acc
);
return
make_tuple
(
dk_acc
,
dv_acc
);
}
}
};
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
0 → 100644
View file @
3d5b0755
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp
View file @
3d5b0755
...
@@ -8,7 +8,8 @@ namespace ck_tile {
...
@@ -8,7 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching
// This class is used for codegen pattern matching
enum
class
BlockFmhaBwdPipelineEnum
enum
class
BlockFmhaBwdPipelineEnum
{
{
KRKTRVR
=
0
,
KRKTRVR_IGLP
=
0
,
KRKTRVR
,
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment