Commit 39fc3d4b authored by danyao12's avatar danyao12
Browse files

fix group deterministic bugs

parent 8c967d76
...@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -132,11 +132,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0) if(hdim_v < 0)
hdim_v = hdim_q; hdim_v = hdim_q;
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
{
std::cerr << "FMHA Bwd kernel currently only supports even headdim" << std::endl;
return false;
}
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
......
...@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) ...@@ -297,7 +297,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr, return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr, args.dq_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqlen_k_ptr, args.seqstart_k_ptr,
args.hdim_q, args.hdim_q,
args.stride_q, args.stride_q,
args.nhead_stride_q, args.nhead_stride_q,
......
...@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel ...@@ -1375,7 +1375,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>> FmhaBwdConvertQGradEmptyKargs<0>>
{ {
const int32_t* seqstart_q_ptr; const int32_t* seqstart_q_ptr;
const int32_t* seqlen_k_ptr; const int32_t* seqstart_k_ptr;
}; };
using Kargs = std::conditional_t<kIsGroupMode, using Kargs = std::conditional_t<kIsGroupMode,
...@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel ...@@ -1411,7 +1411,7 @@ struct FmhaBwdConvertQGradKernel
MakeKargs(const void* dq_acc_ptr, MakeKargs(const void* dq_acc_ptr,
void* dq_ptr, void* dq_ptr,
const void* seqstart_q_ptr, const void* seqstart_q_ptr,
const void* seqlen_k_ptr, const void* seqstart_k_ptr,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq, ck_tile::index_t stride_dq,
ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq,
...@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel ...@@ -1426,7 +1426,7 @@ struct FmhaBwdConvertQGradKernel
nhead_stride_dq}, nhead_stride_dq},
{}, {},
reinterpret_cast<const int32_t*>(seqstart_q_ptr), reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)}; reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
if constexpr(kIsDeterministic) if constexpr(kIsDeterministic)
{ {
...@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel ...@@ -1463,7 +1463,8 @@ struct FmhaBwdConvertQGradKernel
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
// # of required blocks is different in each groups, terminate unnecessary blocks // # of required blocks is different in each groups, terminate unnecessary blocks
// earlier // earlier
if(kargs.seqlen_q <= i_m0) if(kargs.seqlen_q <= i_m0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment