Commit a67bdd63 authored by danyao12's avatar danyao12
Browse files

simplify convert dq

parent 2ef396bb
......@@ -622,15 +622,6 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_hdim}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_bwd_convert_dq_shape_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
......@@ -638,10 +629,13 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
fmha_bwd_convert_dq_shape_{F_idx},
fmha_bwd_convert_dq_trait_{F_idx},
/* BlockSize = */ 256,
{F_bm0},
{F_bn0},
{F_hdim},
{F_mode},
{F_deterministic}>;
{F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
......@@ -699,12 +693,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_rm : int # number of warps along k seqlen (block warps) in gemm4
F_rn : int # number of warps along q seqlen (block warps) in gemm4
F_rk : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m in gemm4
F_wn : int # warp size along n in gemm4
F_wk : int # warp size along k in gemm4
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
......@@ -720,12 +708,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_rm = self.F_rm,
F_rn = self.F_rn,
F_rk = self.F_rk,
F_wm = self.F_wm,
F_wn = self.F_wn,
F_wk = self.F_wk,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
......@@ -741,8 +723,7 @@ class FmhaBwdConvertQGradKernel:
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_r{self.F_rm}x{self.F_rn}x{self.F_rk}" +\
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_{self.F_mode}_o{self.F_occupancy}"
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic'
return n
......@@ -769,7 +750,6 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
if (mode == "group" and spad == "f"):
continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_rm=tile.F_rm2, F_rn=tile.F_rn2, F_rk=tile.F_rk2, F_wm=tile.F_wm0, F_wn=tile.F_wn0, F_wk=tile.F_wk0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
gen.append(k)
......
......@@ -14,12 +14,12 @@ struct BlockFmhaBwdConvertQGrad
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
static constexpr index_t kM0 = Problem::Shape::kM0;
static constexpr index_t kN0 = Problem::Shape::kN0;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::Shape::kQKHeaddim;
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
......
......@@ -561,8 +561,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;
......@@ -586,8 +586,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;
......
......@@ -93,24 +93,29 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
template <typename AccDataType_,
typename QGradDataType_,
typename Shape_,
typename Traits_,
index_t kBlockSize_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_,
bool kIsDeterministic_>
bool kIsDeterministic_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size();
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
......
......@@ -92,20 +92,4 @@ struct TileFmhaBwdShape
// that need load V at once
};
template <typename BlockTile_, // sequence<...
typename BlockWarps_,
typename WarpTile_>
struct TileFmhaBwdConvertQGradShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kQKHeaddim = BlockTile::at(number<2>{}); // Q & K headdim
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment