Commit a67bdd63 authored by danyao12
Browse files

simplify convert dq

parent 2ef396bb
...@@ -622,15 +622,6 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: ...@@ -622,15 +622,6 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype}; using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_hdim}>;
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_bwd_convert_dq_shape_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradShape<fmha_block_tile_{F_idx},
fmha_block_warps_{F_idx},
fmha_warp_tile_{F_idx}>;
using fmha_bwd_convert_dq_trait_{F_idx} = using fmha_bwd_convert_dq_trait_{F_idx} =
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>; ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
...@@ -638,10 +629,13 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} = ...@@ -638,10 +629,13 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType, typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
fmha_bwd_convert_dq_shape_{F_idx}, /* BlockSize = */ 256,
fmha_bwd_convert_dq_trait_{F_idx}, {F_bm0},
{F_bn0},
{F_hdim},
{F_mode}, {F_mode},
{F_deterministic}>; {F_deterministic},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} = using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>; typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
...@@ -699,12 +693,6 @@ class FmhaBwdConvertQGradKernel: ...@@ -699,12 +693,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype : str # data type F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size) F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen F_bn0 : int # tile size along k seqlen
F_rm : int # number of warps along k seqlen (block warps) in gemm4
F_rn : int # number of warps along q seqlen (block warps) in gemm4
F_rk : int # number of warps along gemm-k (not used) in gemm4
F_wm : int # warp size along m in gemm4
F_wn : int # warp size along n in gemm4
F_wk : int # warp size along k in gemm4
F_spad : str # true/false F_spad : str # true/false
F_dpad : str # F_dpad : str #
F_mode : str # value from MODE_MAP F_mode : str # value from MODE_MAP
...@@ -720,12 +708,6 @@ class FmhaBwdConvertQGradKernel: ...@@ -720,12 +708,6 @@ class FmhaBwdConvertQGradKernel:
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0, F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0, F_bn0 = self.F_bn0,
F_rm = self.F_rm,
F_rn = self.F_rn,
F_rk = self.F_rk,
F_wm = self.F_wm,
F_wn = self.F_wn,
F_wk = self.F_wk,
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad], F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
...@@ -741,8 +723,7 @@ class FmhaBwdConvertQGradKernel: ...@@ -741,8 +723,7 @@ class FmhaBwdConvertQGradKernel:
if n != '' : n = 'p' + n if n != '' : n = 'p' + n
return n return n
pn = pad_name() pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_r{self.F_rm}x{self.F_rn}x{self.F_rk}" +\ n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}' if pn != '' : n += f'_{pn}'
if self.F_deterministic == 't' : n += f'_deterministic' if self.F_deterministic == 't' : n += f'_deterministic'
return n return n
...@@ -769,7 +750,6 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: ...@@ -769,7 +750,6 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
if (mode == "group" and spad == "f"): if (mode == "group" and spad == "f"):
continue continue
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
F_rm=tile.F_rm2, F_rn=tile.F_rn2, F_rk=tile.F_rk2, F_wm=tile.F_wm0, F_wn=tile.F_wn0, F_wk=tile.F_wk0,
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
gen.append(k) gen.append(k)
......
...@@ -14,12 +14,12 @@ struct BlockFmhaBwdConvertQGrad ...@@ -14,12 +14,12 @@ struct BlockFmhaBwdConvertQGrad
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>; using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
static constexpr index_t kM0 = Problem::Shape::kM0; static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::Shape::kN0; static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::Shape::kQKHeaddim; static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
......
...@@ -561,8 +561,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -561,8 +561,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType); constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -586,8 +586,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -586,8 +586,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType); constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
......
...@@ -93,24 +93,29 @@ struct BlockFmhaBwdOGradDotOPipelineProblem ...@@ -93,24 +93,29 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
template <typename AccDataType_, template <typename AccDataType_,
typename QGradDataType_, typename QGradDataType_,
typename Shape_, index_t kBlockSize_,
typename Traits_, index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_, bool kIsGroupMode_,
bool kIsDeterministic_> bool kIsDeterministic_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem struct BlockFmhaBwdConvertQGradPipelineProblem
{ {
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>; using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size(); static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_; static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
......
...@@ -92,20 +92,4 @@ struct TileFmhaBwdShape ...@@ -92,20 +92,4 @@ struct TileFmhaBwdShape
// that need load V at once // that need load V at once
}; };
template <typename BlockTile_, // sequence<...
typename BlockWarps_,
typename WarpTile_>
struct TileFmhaBwdConvertQGradShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kQKHeaddim = BlockTile::at(number<2>{}); // Q & K headdim
};
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment