Commit 5d2a5a11 authored by danyao12's avatar danyao12
Browse files

more strides for fa integration

parent fd28454d
......@@ -496,6 +496,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
......@@ -508,6 +509,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
......@@ -518,6 +522,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
......
......@@ -99,6 +99,7 @@ struct fmha_bwd_args
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
......@@ -111,6 +112,9 @@ struct fmha_bwd_args
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
......@@ -121,6 +125,7 @@ struct fmha_bwd_args
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
......@@ -179,6 +184,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc,
......@@ -227,6 +234,8 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
......@@ -307,9 +316,9 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_q,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
......@@ -320,11 +329,11 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_q,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_q,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
......
......@@ -147,6 +147,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t batch_stride_lsed;
};
......@@ -301,6 +303,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
......@@ -350,6 +354,8 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
......@@ -452,6 +458,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc,
......@@ -491,6 +499,8 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
......@@ -687,10 +697,10 @@ struct FmhaBwdDQDKDVKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do;
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
batch_offset_dk;
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
batch_offset_dv;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment