Unverified Commit 645fe812 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] Fix fMHA fwd MakeKargs() compilation errors (#1689)



* Fix mis-matched tuple<> elem types

* Rename MakeKargs() as MakeKargsImpl()

---------
Co-authored-by: default avatarQianfeng <qianfeng.zhang@amd.com>
parent c2bcbb13
......@@ -150,7 +150,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......@@ -200,7 +200,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......
......@@ -281,7 +281,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(args.q_ptr,
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......@@ -320,7 +320,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(args.q_ptr,
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
......
......@@ -304,7 +304,7 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
......@@ -470,7 +470,7 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
......@@ -531,7 +531,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......@@ -591,7 +591,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
......@@ -650,9 +650,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......@@ -714,7 +714,7 @@ struct FmhaBwdDQDKDVKernel
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
......@@ -858,7 +858,7 @@ struct FmhaBwdDQDKDVKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
......@@ -909,7 +909,7 @@ struct FmhaBwdDQDKDVKernel
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......@@ -959,7 +959,7 @@ struct FmhaBwdDQDKDVKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
......@@ -1008,9 +1008,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......
......@@ -64,7 +64,7 @@ struct FmhaFwdKernel
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
......@@ -267,8 +267,8 @@ struct FmhaFwdKernel
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
......@@ -399,9 +399,9 @@ struct FmhaFwdKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -445,7 +445,8 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
MakeKargs(q_ptr,
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
......@@ -489,9 +490,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -533,9 +534,10 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
MakeKargs(q_ptr,
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
......@@ -580,8 +582,8 @@ struct FmhaFwdKernel
}
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargsImpl(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
......@@ -702,9 +704,9 @@ struct FmhaFwdKernel
return kargs;
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -742,7 +744,7 @@ struct FmhaFwdKernel
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......@@ -781,9 +783,9 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
// std::variant can't take in a list initializer, overload for backward compatibility
// std::variant<> can't take in a list initializer, overload for backward compatibility
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
......@@ -819,9 +821,9 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<void*, void*>& drop_seed_offset)
const std::tuple<const void*, const void*>& drop_seed_offset)
{
return MakeKargs(
return MakeKargsImpl(
q_ptr,
k_ptr,
v_ptr,
......@@ -860,7 +862,7 @@ struct FmhaFwdKernel
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
......@@ -868,7 +870,7 @@ struct FmhaFwdKernel
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment