Unverified Commit 645fe812 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] Fix fMHA fwd MakeKargs() compilation errors (#1689)



* Fix mis-matched tuple<> elem types

* Rename MakeKargs() as MakeKargsImpl()

---------
Co-authored-by: default avatarQianfeng <qianfeng.zhang@amd.com>
parent c2bcbb13
...@@ -150,7 +150,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -150,7 +150,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{ {
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
...@@ -200,7 +200,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -200,7 +200,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
......
...@@ -281,7 +281,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -281,7 +281,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments // create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode) if constexpr(FmhaKernel::kIsGroupMode)
{ {
return FmhaKernel::MakeKargs(args.q_ptr, return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
...@@ -320,7 +320,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -320,7 +320,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
return FmhaKernel::MakeKargs(args.q_ptr, return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
......
...@@ -304,7 +304,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -304,7 +304,7 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
...@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -714,7 +714,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -714,7 +714,7 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
...@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop, float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
...@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
......
...@@ -64,7 +64,7 @@ struct FmhaFwdKernel ...@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; }; template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on // clang-format on
__host__ static std::string GetName() CK_TILE_HOST static std::string GetName()
{ {
// sync with generate.py // sync with generate.py
// clang-format off // clang-format off
...@@ -267,8 +267,8 @@ struct FmhaFwdKernel ...@@ -267,8 +267,8 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>; using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
...@@ -399,9 +399,9 @@ struct FmhaFwdKernel ...@@ -399,9 +399,9 @@ struct FmhaFwdKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -445,7 +445,8 @@ struct FmhaFwdKernel ...@@ -445,7 +445,8 @@ struct FmhaFwdKernel
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
MakeKargs(q_ptr, return MakeKargsImpl(
q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
bias_ptr, bias_ptr,
...@@ -489,9 +490,9 @@ struct FmhaFwdKernel ...@@ -489,9 +490,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode> template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -533,9 +534,10 @@ struct FmhaFwdKernel ...@@ -533,9 +534,10 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
MakeKargs(q_ptr, return MakeKargsImpl(
q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
bias_ptr, bias_ptr,
...@@ -580,8 +582,8 @@ struct FmhaFwdKernel ...@@ -580,8 +582,8 @@ struct FmhaFwdKernel
} }
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargsImpl(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
const void* bias_ptr, const void* bias_ptr,
...@@ -702,9 +704,9 @@ struct FmhaFwdKernel ...@@ -702,9 +704,9 @@ struct FmhaFwdKernel
return kargs; return kargs;
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -742,7 +744,7 @@ struct FmhaFwdKernel ...@@ -742,7 +744,7 @@ struct FmhaFwdKernel
bool s_randval, bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset) const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -781,9 +783,9 @@ struct FmhaFwdKernel ...@@ -781,9 +783,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
// std::variant can't take in a list initializer, overload for backward compatibility // std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode> template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs> CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr, MakeKargs(const void* q_ptr,
const void* k_ptr, const void* k_ptr,
const void* v_ptr, const void* v_ptr,
...@@ -819,9 +821,9 @@ struct FmhaFwdKernel ...@@ -819,9 +821,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
float p_drop, float p_drop,
bool s_randval, bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset) const std::tuple<const void*, const void*>& drop_seed_offset)
{ {
return MakeKargs( return MakeKargsImpl(
q_ptr, q_ptr,
k_ptr, k_ptr,
v_ptr, v_ptr,
...@@ -860,7 +862,7 @@ struct FmhaFwdKernel ...@@ -860,7 +862,7 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_, ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_, ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v_)
...@@ -868,7 +870,7 @@ struct FmhaFwdKernel ...@@ -868,7 +870,7 @@ struct FmhaFwdKernel
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment