Commit 7e9d2390 authored by danyao12's avatar danyao12
Browse files

dq_acc stride stuff

parent 224a7b02
......@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "e" "a"; do
for dbias in 0 1 ; do
......
......@@ -137,6 +137,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
......@@ -145,6 +146,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t batch_stride_lsed;
};
......@@ -236,6 +238,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
};
......@@ -286,6 +289,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
......@@ -296,6 +300,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
......@@ -304,6 +309,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias,
......@@ -335,6 +341,7 @@ struct FmhaBwdDQDKDVKernel
stride_k,
stride_v,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
nhead_stride_q,
......@@ -342,6 +349,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
......@@ -352,6 +360,7 @@ struct FmhaBwdDQDKDVKernel
batch_stride_k,
batch_stride_v,
batch_stride_do,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv};
......@@ -431,6 +440,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias,
......@@ -441,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc,
......@@ -471,6 +482,7 @@ struct FmhaBwdDQDKDVKernel
stride_k,
stride_v,
stride_do,
stride_dq_acc,
stride_dk,
stride_dv,
nhead_stride_q,
......@@ -478,6 +490,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_dq_acc,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
......@@ -571,6 +584,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_do = 0;
long_index_t batch_offset_lsed = 0;
long_index_t batch_offset_dq_acc = 0;
long_index_t batch_offset_dk = 0;
long_index_t batch_offset_dv = 0;
long_index_t batch_offset_dbias = 0;
......@@ -581,13 +595,14 @@ struct FmhaBwdDQDKDVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias;
......@@ -627,13 +642,14 @@ struct FmhaBwdDQDKDVKernel
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......@@ -763,16 +779,16 @@ struct FmhaBwdDQDKDVKernel
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_q;
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
......@@ -791,7 +807,8 @@ struct FmhaBwdDQDKDVKernel
{
AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q + batch_offset_q;
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive =
......@@ -799,7 +816,7 @@ struct FmhaBwdDQDKDVKernel
memory_operation_enum::atomic_add>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
......@@ -1366,7 +1383,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t hdim_q;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
};
struct FmhaBwdConvertQGradDeterministicKargs
......@@ -1381,6 +1400,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>>
{
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
};
struct FmhaBwdConvertQGradGroupModeKargs
......@@ -1405,13 +1425,25 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr, dq_ptr, seqlen_q, seqlen_k, hdim_q, stride_dq, nhead_stride_dq},
Kargs kargs{{dq_acc_ptr,
dq_ptr,
seqlen_q,
seqlen_k,
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
batch_stride_dq};
batch_stride_dq,
batch_stride_dq_acc};
if constexpr(kIsDeterministic)
{
......@@ -1429,7 +1461,9 @@ struct FmhaBwdConvertQGradKernel
const void* seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
......@@ -1438,7 +1472,9 @@ struct FmhaBwdConvertQGradKernel
-1, //
hdim_q,
stride_dq,
nhead_stride_dq},
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
......@@ -1477,12 +1513,14 @@ struct FmhaBwdConvertQGradKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
long_index_t batch_offset_dq = 0;
long_index_t batch_offset_dq = 0;
long_index_t batch_offset_dq_acc = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......@@ -1501,7 +1539,8 @@ struct FmhaBwdConvertQGradKernel
}
else
{
batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
}
// for simplicity, batch stride we just modify the pointer
......@@ -1515,14 +1554,15 @@ struct FmhaBwdConvertQGradKernel
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq;
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq, 1),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
......@@ -1533,12 +1573,13 @@ struct FmhaBwdConvertQGradKernel
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq;
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq, 1),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment