Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a67bdd63
"vscode:/vscode.git/clone" did not exist on "560919ab20d13f23057959c66932e676a36094ba"
Commit
a67bdd63
authored
Jul 18, 2024
by
danyao12
Browse files
simplify convert dq
parent
2ef396bb
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
58 deletions
+27
-58
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+7
-27
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
...e/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
+3
-3
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+4
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
...ile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
+13
-8
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+0
-16
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
a67bdd63
...
...
@@ -622,15 +622,6 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_hdim}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_bwd_convert_dq_shape_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
...
...
@@ -638,10 +629,13 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
fmha_bwd_convert_dq_shape_{F_idx},
fmha_bwd_convert_dq_trait_{F_idx},
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic}>;
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
...
...
@@ -699,12 +693,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype
:
str
# data type
F_bm0
:
int
# tile size along q seqlen (block size)
F_bn0
:
int
# tile size along k seqlen
F_rm
:
int
# number of warps along k seqlen (block warps) in gemm4
F_rn
:
int
# number of warps along q seqlen (block warps) in gemm4
F_rk
:
int
# number of warps along gemm-k (not used) in gemm4
F_wm
:
int
# warp size along m in gemm4
F_wn
:
int
# warp size along n in gemm4
F_wk
:
int
# warp size along k in gemm4
F_spad
:
str
# true/false
F_dpad
:
str
#
F_mode
:
str
# value from MODE_MAP
...
...
@@ -720,12 +708,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_bm0
=
self
.
F_bm0
,
F_bn0
=
self
.
F_bn0
,
F_rm
=
self
.
F_rm
,
F_rn
=
self
.
F_rn
,
F_rk
=
self
.
F_rk
,
F_wm
=
self
.
F_wm
,
F_wn
=
self
.
F_wn
,
F_wk
=
self
.
F_wk
,
F_spad
=
BOOL_MAP
[
self
.
F_spad
],
F_dpad
=
BOOL_MAP
[
self
.
F_dpad
],
F_mode
=
MODE_MAP
[
self
.
F_mode
],
...
...
@@ -741,8 +723,7 @@ class FmhaBwdConvertQGradKernel:
if
n
!=
''
:
n
=
'p'
+
n
return
n
pn
=
pad_name
()
n
=
f
"fmha_bwd_convert_dq_d
{
self
.
F_hdim
}
_
{
self
.
F_dtype
}
_b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
_r
{
self
.
F_rm
}
x
{
self
.
F_rn
}
x
{
self
.
F_rk
}
"
+
\
f
"_w
{
self
.
F_wm
}
x
{
self
.
F_wn
}
x
{
self
.
F_wk
}
_
{
self
.
F_mode
}
_o
{
self
.
F_occupancy
}
"
n
=
f
"fmha_bwd_convert_dq_d
{
self
.
F_hdim
}
_
{
self
.
F_dtype
}
_b
{
self
.
F_bm0
}
x
{
self
.
F_bn0
}
_
{
self
.
F_mode
}
_o
{
self
.
F_occupancy
}
"
if
pn
!=
''
:
n
+=
f
'_
{
pn
}
'
if
self
.
F_deterministic
==
't'
:
n
+=
f
'_deterministic'
return
n
...
...
@@ -769,7 +750,6 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
if
(
mode
==
"group"
and
spad
==
"f"
):
continue
k
=
FmhaBwdConvertQGradKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_bm0
=
64
,
F_bn0
=
tile
.
F_bn0
,
F_rm
=
tile
.
F_rm2
,
F_rn
=
tile
.
F_rn2
,
F_rk
=
tile
.
F_rk2
,
F_wm
=
tile
.
F_wm0
,
F_wn
=
tile
.
F_wn0
,
F_wk
=
tile
.
F_wk0
,
F_spad
=
spad
,
F_dpad
=
dpad
,
F_mode
=
mode
,
F_occupancy
=
get_occupancy
(
dtype
,
hdim
),
F_deterministic
=
deterministic
)
gen
.
append
(
k
)
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp
View file @
a67bdd63
...
...
@@ -14,12 +14,12 @@ struct BlockFmhaBwdConvertQGrad
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
QGradDataType
=
remove_cvref_t
<
typename
Problem
::
QGradDataType
>
;
static
constexpr
index_t
kM0
=
Problem
::
Shape
::
kM0
;
static
constexpr
index_t
kN0
=
Problem
::
Shape
::
kN0
;
static
constexpr
index_t
kM0
=
Problem
::
kM0
;
static
constexpr
index_t
kN0
=
Problem
::
kN0
;
static
constexpr
index_t
kBlockPerCu
=
Problem
::
kBlockPerCu
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
Shape
::
kQKHeaddim
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
kQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
a67bdd63
...
...
@@ -561,8 +561,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
Shape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
Shape
::
kQKHeaddim
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -586,8 +586,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
Shape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
Shape
::
kQKHeaddim
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
kQKHeaddim
;
constexpr
index_t
K1
=
16
/
sizeof
(
AccDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp
View file @
a67bdd63
...
...
@@ -93,24 +93,29 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
template
<
typename
AccDataType_
,
typename
QGradDataType_
,
typename
Shape_
,
typename
Traits_
,
index_t
kBlockSize_
,
index_t
kM0_
,
index_t
kN0_
,
index_t
kQKHeaddim_
,
bool
kIsGroupMode_
,
bool
kIsDeterministic_
>
bool
kIsDeterministic_
,
typename
Traits_
>
struct
BlockFmhaBwdConvertQGradPipelineProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
QGradDataType
=
remove_cvref_t
<
QGradDataType_
>
;
using
Shape
=
remove_cvref_t
<
Shape_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
Shape
::
NumWarps
*
get_warp_size
();
static_assert
(
0
<
kBlockSize_
&&
kBlockSize_
%
get_warp_size
()
==
0
,
"kBlockSize should be divisible by get_warp_size()"
);
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kM0
=
kM0_
;
static
constexpr
index_t
kN0
=
kN0_
;
static
constexpr
index_t
kQKHeaddim
=
kQKHeaddim_
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
bool
kIsDeterministic
=
kIsDeterministic_
;
static_assert
(
0
<
kBlockSize
&&
kBlockSize
%
get_warp_size
()
==
0
,
"kBlockSize should be divisible by get_warp_size()"
);
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
Traits
::
kPadHeadDimQ
;
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
a67bdd63
...
...
@@ -92,20 +92,4 @@ struct TileFmhaBwdShape
// that need load V at once
};
template
<
typename
BlockTile_
,
// sequence<...
typename
BlockWarps_
,
typename
WarpTile_
>
struct
TileFmhaBwdConvertQGradShape
{
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
using
WarpTile
=
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kQKHeaddim
=
BlockTile
::
at
(
number
<
2
>
{});
// Q & K headdim
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment