Commit 78f33529 authored by danyao12's avatar danyao12
Browse files

no_coex update

parent 8ac3eb39
......@@ -403,12 +403,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name, io_perm);
return r;
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_nocoex_a32, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_ext_name, io_perm);
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
......@@ -420,12 +430,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name, io_perm);
return r;
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_nocoex_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_nocoex_causal_a32, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_ext_name, io_perm);
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
......@@ -439,12 +459,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name, io_perm);
return r;
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_nocoex_a32, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_ext_name, io_perm);
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
......@@ -456,12 +486,22 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_asm_atomic_fp32 == true) && (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name, io_perm);
return r;
if(t.is_asm_no_coex == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_nocoex_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_nocoex_causal_a32, bwd_ext_name, io_perm);
return r;
}}
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_ext_name = "bwd_ext_bf16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_ext_bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_ext_name, io_perm);
return r;
}}
}}
else if(t.is_asm_atomic_fp32 == false){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
......
......@@ -95,7 +95,8 @@ auto create_args(int argc, char* argv[])
.insert("ext_asm", "0", "if set to 1, some cases will call the ext asm dqdkdv kernel")
.insert("asm_atomic_fp32",
"1",
"if set to 0, atomic fp16/bf16 is used when calling the ext asm dqdkdv kernel");
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when ext_asm is set to 1")
.insert("asm_no_coex", "0", "if set to 1 will use non-coexectuion kernel when ext_asm is set to 1");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
......@@ -186,6 +187,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool deterministic = arg_parser.get_bool("deterministic");
bool ext_asm = arg_parser.get_bool("ext_asm");
bool asm_atomic_fp32 = arg_parser.get_bool("asm_atomic_fp32");
bool asm_no_coex = arg_parser.get_bool("asm_no_coex");
ck_tile::stream_config stream_config{nullptr,
true,
......@@ -312,9 +314,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q});
if(init_method == 0)
{
......@@ -424,7 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
s_randval,
deterministic,
ext_asm,
asm_atomic_fp32};
asm_atomic_fp32,
asm_no_coex};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
......@@ -438,6 +439,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dq_acc = hdim_q;
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
......@@ -450,6 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dq_acc = shape_seqlen_q * hdim_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
......@@ -503,7 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_dq_acc,
stride_q, // stride_dq
stride_dk,
stride_dv,
......@@ -516,7 +519,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_dq_acc,
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
......
......@@ -440,6 +440,7 @@ struct fmha_bwd_traits
bool is_deterministic;
bool uses_ext_asm;
bool is_asm_atomic_fp32;
bool is_asm_no_coex;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -9,17 +9,19 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 1 ; do
for perm in 0 1 ; do
for hdim in 128 ; do
for asm_atomic_fp32 in 0 1 ; do
for asm_no_coex in 0 1 ; do
for mask in 0 1 ; do
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=4 -h=2 -d=$hdim -s=512 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=3 -d=$hdim -s=768 -iperm=$perm -operm=$perm -mask=$mask -ext_asm=1 -asm_atomic_fp32=$asm_atomic_fp32 -asm_no_coex=$asm_no_coex -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
set +x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment