Commit 645fe812 (unverified), authored by Po Yen Chen, committed by GitHub
Browse files

[CK_TILE] Fix fMHA fwd MakeKargs() compilation errors (#1689)



* Fix mis-matched tuple<> elem types

* Rename MakeKargs() as MakeKargsImpl()

---------
Co-authored-by: Qianfeng <qianfeng.zhang@amd.com>
parent c2bcbb13
...@@ -150,113 +150,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -150,113 +150,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{ {
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_ptr, args.lse_ptr,
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr, args.dq_acc_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale, args.scale,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc, args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc, args.nhead_stride_dq_acc,
args.nhead_stride_dk, args.nhead_stride_dk,
args.nhead_stride_dv, args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.split_stride_dq_acc, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.lse_ptr, args.lse_ptr,
args.do_ptr, args.do_ptr,
args.d_ptr, args.d_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.dk_ptr, args.dk_ptr,
args.dv_ptr, args.dv_ptr,
args.dbias_ptr, args.dbias_ptr,
args.dq_acc_ptr, args.dq_acc_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale, args.scale,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_do, args.stride_do,
args.stride_dq_acc, args.stride_dq_acc,
args.stride_dk, args.stride_dk,
args.stride_dv, args.stride_dv,
args.stride_dbias, args.stride_dbias,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_lsed, args.nhead_stride_lsed,
args.nhead_stride_dq_acc, args.nhead_stride_dq_acc,
args.nhead_stride_dk, args.nhead_stride_dk,
args.nhead_stride_dv, args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_do, args.batch_stride_do,
args.batch_stride_lsed, args.batch_stride_lsed,
args.batch_stride_dq_acc, args.batch_stride_dq_acc,
args.batch_stride_dk, args.batch_stride_dk,
args.batch_stride_dv, args.batch_stride_dv,
args.batch_stride_dbias, args.batch_stride_dbias,
args.split_stride_dq_acc, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
......
...@@ -281,87 +281,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -281,87 +281,87 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode) if constexpr(FmhaKernel::kIsGroupMode)
{ {
return FmhaKernel::MakeKargs(args.q_ptr, return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
args.scale_o, args.scale_o,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval, args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return FmhaKernel::MakeKargs(args.q_ptr, return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr, args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
args.scale_o, args.scale_o,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval, args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval, args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval, args.batch_stride_randval,
args.batch_stride_lse, args.batch_stride_lse,
args.batch_stride_o, args.batch_stride_o,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type,
args.p_drop, args.p_drop,
args.s_randval, args.s_randval,
args.drop_seed_offset); args.drop_seed_offset);
} }
}(); }();
......
...@@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel ...@@ -304,64 +304,64 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
const void* lse_ptr, const void* lse_ptr,
const void* do_ptr, const void* do_ptr,
const void* d_ptr, const void* d_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* dk_ptr, void* dk_ptr,
void* dv_ptr, void* dv_ptr,
void* dbias_ptr, void* dbias_ptr,
void* dq_acc_ptr, void* dq_acc_ptr,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q, ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk, ck_tile::index_t nhead_ratio_qk,
float scale, float scale,
ck_tile::index_t stride_q, ck_tile::index_t stride_q,
ck_tile::index_t stride_k, ck_tile::index_t stride_k,
ck_tile::index_t stride_v, ck_tile::index_t stride_v,
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_do, ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk, ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv, ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias, ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias, ck_tile::index_t batch_stride_dbias,
ck_tile::index_t split_stride_dq_acc, ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset) drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -714,54 +714,54 @@ struct FmhaBwdDQDKDVKernel ...@@ -714,54 +714,54 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
const void* lse_ptr, const void* lse_ptr,
const void* do_ptr, const void* do_ptr,
const void* d_ptr, const void* d_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* dk_ptr, void* dk_ptr,
void* dv_ptr, void* dv_ptr,
void* dbias_ptr, void* dbias_ptr,
void* dq_acc_ptr, void* dq_acc_ptr,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
const void* seqstart_k_ptr, const void* seqstart_k_ptr,
const void* seqlen_k_ptr, const void* seqlen_k_ptr,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q, ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk, ck_tile::index_t nhead_ratio_qk,
float scale, float scale,
ck_tile::index_t stride_q, ck_tile::index_t stride_q,
ck_tile::index_t stride_k, ck_tile::index_t stride_k,
ck_tile::index_t stride_v, ck_tile::index_t stride_v,
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_do, ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk, ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv, ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias, ck_tile::index_t stride_dbias,
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t split_stride_dq_acc, ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset) drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
......
...@@ -64,7 +64,7 @@ struct FmhaFwdKernel ...@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on // clang-format on
__host__ static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
...@@ -267,50 +267,50 @@ struct FmhaFwdKernel ...@@ -267,50 +267,50 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k, ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q, ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk, ck_tile::index_t nhead_ratio_qk,
float scale_s, float scale_s,
float scale_p, float scale_p,
float scale_o, float scale_o,
ck_tile::index_t stride_q, ck_tile::index_t stride_q,
ck_tile::index_t stride_k, ck_tile::index_t stride_k,
ck_tile::index_t stride_v, ck_tile::index_t stride_v,
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_o, ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset) drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -399,9 +399,9 @@ struct FmhaFwdKernel ...@@ -399,9 +399,9 @@ struct FmhaFwdKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -445,53 +445,54 @@ struct FmhaFwdKernel ...@@ -445,53 +445,54 @@ struct FmhaFwdKernel
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
MakeKargs(q_ptr, return MakeKargsImpl(
k_ptr, q_ptr,
v_ptr, k_ptr,
bias_ptr, v_ptr,
rand_val_ptr, bias_ptr,
lse_ptr, rand_val_ptr,
o_ptr, lse_ptr,
seqlen_q, o_ptr,
seqlen_k, seqlen_q,
hdim_q, seqlen_k,
hdim_v, hdim_q,
num_head_q, hdim_v,
nhead_ratio_qk, num_head_q,
scale_s, nhead_ratio_qk,
scale_p, scale_s,
scale_o, scale_p,
stride_q, scale_o,
stride_k, stride_q,
stride_v, stride_k,
stride_bias, stride_v,
stride_randval, stride_bias,
stride_o, stride_randval,
nhead_stride_q, stride_o,
nhead_stride_k, nhead_stride_q,
nhead_stride_v, nhead_stride_k,
nhead_stride_bias, nhead_stride_v,
nhead_stride_randval, nhead_stride_bias,
nhead_stride_lse, nhead_stride_randval,
nhead_stride_o, nhead_stride_lse,
batch_stride_q, nhead_stride_o,
batch_stride_k, batch_stride_q,
batch_stride_v, batch_stride_k,
batch_stride_bias, batch_stride_v,
batch_stride_randval, batch_stride_bias,
batch_stride_lse, batch_stride_randval,
batch_stride_o, batch_stride_lse,
window_size_left, batch_stride_o,
window_size_right, window_size_left,
mask_type, window_size_right,
p_drop, mask_type,
s_randval, p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -533,91 +534,92 @@ struct FmhaFwdKernel ...@@ -533,91 +534,92 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
MakeKargs(q_ptr, return MakeKargsImpl(
k_ptr, q_ptr,
v_ptr, k_ptr,
bias_ptr, v_ptr,
rand_val_ptr, bias_ptr,
lse_ptr, rand_val_ptr,
o_ptr, lse_ptr,
seqlen_q, o_ptr,
seqlen_k, seqlen_q,
hdim_q, seqlen_k,
hdim_v, hdim_q,
num_head_q, hdim_v,
nhead_ratio_qk, num_head_q,
scale_s, nhead_ratio_qk,
scale_p, scale_s,
scale_o, scale_p,
stride_q, scale_o,
stride_k, stride_q,
stride_v, stride_k,
stride_bias, stride_v,
stride_randval, stride_bias,
stride_o, stride_randval,
nhead_stride_q, stride_o,
nhead_stride_k, nhead_stride_q,
nhead_stride_v, nhead_stride_k,
nhead_stride_bias, nhead_stride_v,
nhead_stride_randval, nhead_stride_bias,
nhead_stride_lse, nhead_stride_randval,
nhead_stride_o, nhead_stride_lse,
batch_stride_q, nhead_stride_o,
batch_stride_k, batch_stride_q,
batch_stride_v, batch_stride_k,
batch_stride_bias, batch_stride_v,
batch_stride_randval, batch_stride_bias,
batch_stride_lse, batch_stride_randval,
batch_stride_o, batch_stride_lse,
window_size_left, batch_stride_o,
window_size_right, window_size_left,
mask_type, window_size_right,
p_drop, mask_type,
s_randval, p_drop,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); s_randval,
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
void* rand_val_ptr, void* rand_val_ptr,
void* lse_ptr, void* lse_ptr,
void* o_ptr, void* o_ptr,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
const void* seqstart_k_ptr, const void* seqstart_k_ptr,
const void* seqlen_k_ptr, const void* seqlen_k_ptr,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v, ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q, ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk, ck_tile::index_t nhead_ratio_qk,
float scale_s, float scale_s,
float scale_p, float scale_p,
float scale_o, float scale_o,
ck_tile::index_t stride_q, ck_tile::index_t stride_q,
ck_tile::index_t stride_k, ck_tile::index_t stride_k,
ck_tile::index_t stride_v, ck_tile::index_t stride_v,
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_o, ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset) drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -702,9 +704,9 @@ struct FmhaFwdKernel ...@@ -702,9 +704,9 @@ struct FmhaFwdKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -742,7 +744,7 @@ struct FmhaFwdKernel ...@@ -742,7 +744,7 @@ struct FmhaFwdKernel
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -781,9 +783,9 @@ struct FmhaFwdKernel ...@@ -781,9 +783,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -819,9 +821,9 @@ struct FmhaFwdKernel ...@@ -819,9 +821,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -860,15 +862,15 @@ struct FmhaFwdKernel ...@@ -860,15 +862,15 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_, ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_, ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v_)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment