Commit 7e9d2390 authored by danyao12
Browse files

Add dedicated dq_acc strides (row, nhead, batch) instead of reusing the q/dq strides

parent 224a7b02
...@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1' ...@@ -11,7 +11,7 @@ COMMON_ARGS='-v=1'
set -x set -x
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do for perm in 0 1 ; do
for hdim in 64 ; do for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do for mode in 0 1 ; do
for bias in "n" "e" "a"; do for bias in "n" "e" "a"; do
for dbias in 0 1 ; do for dbias in 0 1 ; do
......
...@@ -137,6 +137,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -137,6 +137,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_do; ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dk; ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv; ck_tile::index_t stride_dv;
...@@ -145,6 +146,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -145,6 +146,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed; ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t batch_stride_lsed; ck_tile::index_t batch_stride_lsed;
}; };
...@@ -236,6 +238,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -236,6 +238,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
}; };
...@@ -286,6 +289,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -286,6 +289,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_do, ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk, ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv, ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias, ck_tile::index_t stride_dbias,
...@@ -296,6 +300,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -296,6 +300,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_k,
...@@ -304,6 +309,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -304,6 +309,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_do,
ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_lsed,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dk,
ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dv,
ck_tile::index_t batch_stride_dbias, ck_tile::index_t batch_stride_dbias,
...@@ -335,6 +341,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -335,6 +341,7 @@ struct FmhaBwdDQDKDVKernel
stride_k, stride_k,
stride_v, stride_v,
stride_do, stride_do,
stride_dq_acc,
stride_dk, stride_dk,
stride_dv, stride_dv,
nhead_stride_q, nhead_stride_q,
...@@ -342,6 +349,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -342,6 +349,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_do, nhead_stride_do,
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_dq_acc,
batch_stride_lsed}, // args for common karg batch_stride_lsed}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
{}, // placeholder for dbias {}, // placeholder for dbias
...@@ -352,6 +360,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -352,6 +360,7 @@ struct FmhaBwdDQDKDVKernel
batch_stride_k, batch_stride_k,
batch_stride_v, batch_stride_v,
batch_stride_do, batch_stride_do,
batch_stride_dq_acc,
batch_stride_dk, batch_stride_dk,
batch_stride_dv}; batch_stride_dv};
...@@ -431,6 +440,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -431,6 +440,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval, ck_tile::index_t stride_randval,
ck_tile::index_t stride_do, ck_tile::index_t stride_do,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t stride_dk, ck_tile::index_t stride_dk,
ck_tile::index_t stride_dv, ck_tile::index_t stride_dv,
ck_tile::index_t stride_dbias, ck_tile::index_t stride_dbias,
...@@ -441,6 +451,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -441,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_lsed,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc, ck_tile::index_t split_stride_dq_acc,
...@@ -471,6 +482,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -471,6 +482,7 @@ struct FmhaBwdDQDKDVKernel
stride_k, stride_k,
stride_v, stride_v,
stride_do, stride_do,
stride_dq_acc,
stride_dk, stride_dk,
stride_dv, stride_dv,
nhead_stride_q, nhead_stride_q,
...@@ -478,6 +490,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -478,6 +490,7 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_do, nhead_stride_do,
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_dq_acc,
batch_stride_lsed}, // args for common karg batch_stride_lsed}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
{}, // placeholder for dbias {}, // placeholder for dbias
...@@ -571,6 +584,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -571,6 +584,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t batch_offset_randval = 0; long_index_t batch_offset_randval = 0;
long_index_t batch_offset_do = 0; long_index_t batch_offset_do = 0;
long_index_t batch_offset_lsed = 0; long_index_t batch_offset_lsed = 0;
long_index_t batch_offset_dq_acc = 0;
long_index_t batch_offset_dk = 0; long_index_t batch_offset_dk = 0;
long_index_t batch_offset_dv = 0; long_index_t batch_offset_dv = 0;
long_index_t batch_offset_dbias = 0; long_index_t batch_offset_dbias = 0;
...@@ -581,13 +595,14 @@ struct FmhaBwdDQDKDVKernel ...@@ -581,13 +595,14 @@ struct FmhaBwdDQDKDVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do; batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed; batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = key_start * kargs.stride_dk; batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dv = key_start * kargs.stride_dv; batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = query_start * kargs.stride_bias; batch_offset_bias = query_start * kargs.stride_bias;
...@@ -627,13 +642,14 @@ struct FmhaBwdDQDKDVKernel ...@@ -627,13 +642,14 @@ struct FmhaBwdDQDKDVKernel
} }
else else
{ {
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do; batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed; batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk; batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv; batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
...@@ -763,16 +779,16 @@ struct FmhaBwdDQDKDVKernel ...@@ -763,16 +779,16 @@ struct FmhaBwdDQDKDVKernel
{ {
AccDataType* dq_acc_ptr = AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q + static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc + static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_q; batch_offset_dq_acc;
auto dq_acc_dram = [&]() { auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive = const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global>( make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr, dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q), make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1), make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{}, number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{}); number<1>{});
...@@ -791,7 +807,8 @@ struct FmhaBwdDQDKDVKernel ...@@ -791,7 +807,8 @@ struct FmhaBwdDQDKDVKernel
{ {
AccDataType* dq_acc_ptr = AccDataType* dq_acc_ptr =
reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q + batch_offset_q; static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
batch_offset_dq_acc;
auto dq_acc_dram = [&]() { auto dq_acc_dram = [&]() {
const auto dq_acc_dram_naive = const auto dq_acc_dram_naive =
...@@ -799,7 +816,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -799,7 +816,7 @@ struct FmhaBwdDQDKDVKernel
memory_operation_enum::atomic_add>( memory_operation_enum::atomic_add>(
dq_acc_ptr, dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q), make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1), make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{}, number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{}); number<1>{});
...@@ -1366,7 +1383,9 @@ struct FmhaBwdConvertQGradKernel ...@@ -1366,7 +1383,9 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t hdim_q; ck_tile::index_t hdim_q;
ck_tile::index_t stride_dq; ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq; ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
}; };
struct FmhaBwdConvertQGradDeterministicKargs struct FmhaBwdConvertQGradDeterministicKargs
...@@ -1381,6 +1400,7 @@ struct FmhaBwdConvertQGradKernel ...@@ -1381,6 +1400,7 @@ struct FmhaBwdConvertQGradKernel
FmhaBwdConvertQGradEmptyKargs<0>> FmhaBwdConvertQGradEmptyKargs<0>>
{ {
ck_tile::index_t batch_stride_dq; ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
}; };
struct FmhaBwdConvertQGradGroupModeKargs struct FmhaBwdConvertQGradGroupModeKargs
...@@ -1405,13 +1425,25 @@ struct FmhaBwdConvertQGradKernel ...@@ -1405,13 +1425,25 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t seqlen_k, ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq, ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc) ck_tile::index_t split_stride_dq_acc)
{ {
Kargs kargs{{dq_acc_ptr, dq_ptr, seqlen_q, seqlen_k, hdim_q, stride_dq, nhead_stride_dq}, Kargs kargs{{dq_acc_ptr,
dq_ptr,
seqlen_q,
seqlen_k,
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{}, {},
batch_stride_dq}; batch_stride_dq,
batch_stride_dq_acc};
if constexpr(kIsDeterministic) if constexpr(kIsDeterministic)
{ {
...@@ -1429,7 +1461,9 @@ struct FmhaBwdConvertQGradKernel ...@@ -1429,7 +1461,9 @@ struct FmhaBwdConvertQGradKernel
const void* seqstart_k_ptr, const void* seqstart_k_ptr,
ck_tile::index_t hdim_q, ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq, ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc) ck_tile::index_t split_stride_dq_acc)
{ {
Kargs kargs{{dq_acc_ptr, Kargs kargs{{dq_acc_ptr,
...@@ -1438,7 +1472,9 @@ struct FmhaBwdConvertQGradKernel ...@@ -1438,7 +1472,9 @@ struct FmhaBwdConvertQGradKernel
-1, // -1, //
hdim_q, hdim_q,
stride_dq, stride_dq,
nhead_stride_dq}, stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{}, {},
reinterpret_cast<const int32_t*>(seqstart_q_ptr), reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)}; reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
...@@ -1477,12 +1513,14 @@ struct FmhaBwdConvertQGradKernel ...@@ -1477,12 +1513,14 @@ struct FmhaBwdConvertQGradKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
long_index_t batch_offset_dq = 0; long_index_t batch_offset_dq = 0;
long_index_t batch_offset_dq_acc = 0;
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_dq = query_start * kargs.stride_dq; batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -1501,7 +1539,8 @@ struct FmhaBwdConvertQGradKernel ...@@ -1501,7 +1539,8 @@ struct FmhaBwdConvertQGradKernel
} }
else else
{ {
batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq; batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
} }
// for simplicity, batch stride we just modify the pointer // for simplicity, batch stride we just modify the pointer
...@@ -1515,14 +1554,15 @@ struct FmhaBwdConvertQGradKernel ...@@ -1515,14 +1554,15 @@ struct FmhaBwdConvertQGradKernel
{ {
const AccDataType* dq_acc_ptr = const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) + reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq; static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0); const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr, dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q), make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq, 1), make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{}, number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{}); number<1>{});
return pad_tensor_view(dq_acc_dram_naive, return pad_tensor_view(dq_acc_dram_naive,
...@@ -1533,12 +1573,13 @@ struct FmhaBwdConvertQGradKernel ...@@ -1533,12 +1573,13 @@ struct FmhaBwdConvertQGradKernel
{ {
const AccDataType* dq_acc_ptr = const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) + reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq) + batch_offset_dq; static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>( auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr, dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q), make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq, 1), make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{}, number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{}); number<1>{});
return pad_tensor_view(dq_acc_dram_naive, return pad_tensor_view(dq_acc_dram_naive,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment