Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d4de8495
"vscode:/vscode.git/clone" did not exist on "10575caca4f83223d1cda56327e3801d5eccd443"
Commit
d4de8495
authored
Oct 08, 2024
by
danyao12
Browse files
rename & ensure thread safety
parent
871c7556
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
146 additions
and
146 deletions
+146
-146
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+1
-1
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
+96
-96
example/ck_tile/01_fmha/fmha_bwd.cpp
example/ck_tile/01_fmha/fmha_bwd.cpp
+19
-19
example/ck_tile/01_fmha/fmha_bwd.hpp
example/ck_tile/01_fmha/fmha_bwd.hpp
+4
-4
example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_a32.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_a32.cpp
+1
-1
example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_causal_a32.cpp
example/ck_tile/01_fmha/hsaco/bwd_bf16_spec_causal_a32.cpp
+1
-1
example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_a32.cpp
example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_a32.cpp
+1
-1
example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_causal_a32.cpp
example/ck_tile/01_fmha/hsaco/bwd_fp16_spec_causal_a32.cpp
+1
-1
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
+4
-4
example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh
example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh
+12
-12
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
+4
-4
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
+2
-2
No files found.
example/ck_tile/01_fmha/CMakeLists.txt
View file @
d4de8495
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
...
@@ -67,7 +67,7 @@ set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# to be included in "make all/install/check"
# to be included in "make all/install/check"
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
message
(
"adding example
${
EXAMPLE_FMHA_BWD
}
"
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_
nocoex
_a32.cpp hsaco/bwd_bf16_
nocoex
_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_
nocoex
_a32.cpp hsaco/bwd_fp16_
nocoex
_causal_a32.cpp fmha_bwd.cpp
)
add_executable
(
${
EXAMPLE_FMHA_BWD
}
EXCLUDE_FROM_ALL hsaco/bwd_bf16_a16.cpp hsaco/bwd_bf16_a16_rtz.cpp hsaco/bwd_bf16_a32.cpp hsaco/bwd_bf16_causal_a16.cpp hsaco/bwd_bf16_causal_a16_rtz.cpp hsaco/bwd_bf16_causal_a32.cpp hsaco/bwd_bf16_
spec
_a32.cpp hsaco/bwd_bf16_
spec
_causal_a32.cpp hsaco/bwd_fp16_a16.cpp hsaco/bwd_fp16_a32.cpp hsaco/bwd_fp16_causal_a16.cpp hsaco/bwd_fp16_causal_a32.cpp hsaco/bwd_fp16_
spec
_a32.cpp hsaco/bwd_fp16_
spec
_causal_a32.cpp fmha_bwd.cpp
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_include_directories
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
target_sources
(
${
EXAMPLE_FMHA_BWD
}
PRIVATE
${
FMHA_BWD_GEN_BLOBS
}
)
...
...
example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
View file @
d4de8495
...
@@ -188,7 +188,7 @@ struct p2
...
@@ -188,7 +188,7 @@ struct p2
unsigned int _p0;
unsigned int _p0;
unsigned int _p1;
unsigned int _p1;
}};
}};
struct __attribute__((packed)) fmha_bwd_
asm
_args
struct __attribute__((packed)) fmha_bwd_
v3
_args
{{
{{
void* ptr_dq;
void* ptr_dq;
p2 _p0;
p2 _p0;
...
@@ -224,7 +224,7 @@ struct __attribute__((packed)) fmha_bwd_asm_args
...
@@ -224,7 +224,7 @@ struct __attribute__((packed)) fmha_bwd_asm_args
p3 _p15;
p3 _p15;
}};
}};
struct __attribute__((packed)) fmha_bwd_xqa_
asm
_args
struct __attribute__((packed)) fmha_bwd_xqa_
v3
_args
{{
{{
void* ptr_dq;
void* ptr_dq;
p2 _p0;
p2 _p0;
...
@@ -270,7 +270,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_asm_args
...
@@ -270,7 +270,7 @@ struct __attribute__((packed)) fmha_bwd_xqa_asm_args
p3 _p20;
p3 _p20;
}};
}};
struct fmha_bwd_
ext
_traits
struct fmha_bwd_
v3
_traits
{{
{{
int b;
int b;
int h;
int h;
...
@@ -283,17 +283,17 @@ struct fmha_bwd_ext_traits
...
@@ -283,17 +283,17 @@ struct fmha_bwd_ext_traits
int ts_kv;
int ts_kv;
}};
}};
class fmha_bwd_
ext
_kernel
class fmha_bwd_
v3
_kernel
{{
{{
public:
public:
fmha_bwd_
ext
_kernel(const std::string& name, unsigned char buffer[])
fmha_bwd_
v3
_kernel(const std::string& name, unsigned char buffer[])
{{
{{
HIP_CALL(hipModuleLoadData(&module, buffer));
HIP_CALL(hipModuleLoadData(&module, buffer));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name.c_str()));
}}
}}
void
void
launch_kernel(fmha_bwd_
ext
_traits fmha_
ext
_traits, fmha_bwd_
asm
_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_
v3
_traits fmha_
v3
_traits, fmha_bwd_
v3
_args args, const ck_tile::stream_config& s) const
{{
{{
size_t arg_size = sizeof(args);
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...
@@ -303,12 +303,12 @@ class fmha_bwd_ext_kernel
...
@@ -303,12 +303,12 @@ class fmha_bwd_ext_kernel
HIP_LAUNCH_PARAM_END}};
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int bdx = 256;
int gdx = fmha_
ext
_traits.s / fmha_
ext
_traits.ts_kv;
int gdx = fmha_
v3
_traits.s / fmha_
v3
_traits.ts_kv;
int gdy = fmha_
ext
_traits.h;
int gdy = fmha_
v3
_traits.h;
int gdz = fmha_
ext
_traits.b;
int gdz = fmha_
v3
_traits.b;
if(fmha_
ext
_traits.mask > 0)
if(fmha_
v3
_traits.mask > 0)
{{
{{
int num_tg = fmha_
ext
_traits.s / fmha_
ext
_traits.ts_kv;
int num_tg = fmha_
v3
_traits.s / fmha_
v3
_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
HIP_CALL(hipModuleLaunchKernel(kernel_func,
...
@@ -325,7 +325,7 @@ class fmha_bwd_ext_kernel
...
@@ -325,7 +325,7 @@ class fmha_bwd_ext_kernel
}}
}}
void
void
launch_kernel(fmha_bwd_
ext
_traits fmha_
ext
_traits, fmha_bwd_xqa_
asm
_args args, const ck_tile::stream_config& s) const
launch_kernel(fmha_bwd_
v3
_traits fmha_
v3
_traits, fmha_bwd_xqa_
v3
_args args, const ck_tile::stream_config& s) const
{{
{{
size_t arg_size = sizeof(args);
size_t arg_size = sizeof(args);
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
void* config[] = {{HIP_LAUNCH_PARAM_BUFFER_POINTER,
...
@@ -335,12 +335,12 @@ class fmha_bwd_ext_kernel
...
@@ -335,12 +335,12 @@ class fmha_bwd_ext_kernel
HIP_LAUNCH_PARAM_END}};
HIP_LAUNCH_PARAM_END}};
int bdx = 256;
int bdx = 256;
int gdx = fmha_
ext
_traits.s / fmha_
ext
_traits.ts_kv;
int gdx = fmha_
v3
_traits.s / fmha_
v3
_traits.ts_kv;
int gdy = fmha_
ext
_traits.h;
int gdy = fmha_
v3
_traits.h;
int gdz = fmha_
ext
_traits.b;
int gdz = fmha_
v3
_traits.b;
if(fmha_
ext
_traits.mask > 0)
if(fmha_
v3
_traits.mask > 0)
{{
{{
int num_tg = fmha_
ext
_traits.s / fmha_
ext
_traits.ts_kv;
int num_tg = fmha_
v3
_traits.s / fmha_
v3
_traits.ts_kv;
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
gdx = (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
}}
}}
HIP_CALL(hipModuleLaunchKernel(kernel_func,
HIP_CALL(hipModuleLaunchKernel(kernel_func,
...
@@ -374,11 +374,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
...
@@ -374,11 +374,11 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
}}
}}
template <typename dot_do_o_trait_>
template <typename dot_do_o_trait_>
float fmha_
ext_
bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
ext_asm
[], const std::string& bwd_
ext
_name, bool io_perm)
float fmha_bwd
_v3
_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
v3_buf
[], const std::string& bwd_
v3
_name, bool io_perm)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
ext
_name << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
v3
_name << std::flush;
fmha_bwd_
asm
_args args;
fmha_bwd_
v3
_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -406,7 +406,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -406,7 +406,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
args.Hs = stride_head;
args.Hs = stride_head;
args.BAs = stride_batch;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.Seqs = stride_seqlen;
auto traits = fmha_bwd_
ext
_traits{{a.batch,
auto traits = fmha_bwd_
v3
_traits{{a.batch,
a.nhead_q,
a.nhead_q,
a.seqlen_q,
a.seqlen_q,
a.hdim_q,
a.hdim_q,
...
@@ -414,7 +414,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -414,7 +414,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
a.mask_type,
a.mask_type,
32,
32,
128}};
128}};
fmha_bwd_
ext
_kernel impl(HSA_KERNEL, bwd_
ext_asm);
static
fmha_bwd_
v3
_kernel impl(HSA_KERNEL, bwd_
v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
...
@@ -422,11 +422,11 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -422,11 +422,11 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
}}
}}
template <typename dot_do_o_trait_>
template <typename dot_do_o_trait_>
float fmha_
ext_
bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
ext_asm
[], const std::string& bwd_
ext
_name, bool io_perm)
float fmha_bwd
_v3
_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
v3_buf
[], const std::string& bwd_
v3
_name, bool io_perm)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
ext
_name << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
v3
_name << std::flush;
fmha_bwd_xqa_
asm
_args args;
fmha_bwd_xqa_
v3
_args args;
args.ptr_dq = a.dq_ptr;
args.ptr_dq = a.dq_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -469,7 +469,7 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
...
@@ -469,7 +469,7 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
args.BAs_kv = stride_batch_kv;
args.BAs_kv = stride_batch_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_kv = stride_seqlen_kv;
args.Seqs_dkv = stride_seqlen_dkv;
args.Seqs_dkv = stride_seqlen_dkv;
auto traits = fmha_bwd_
ext
_traits{{a.batch,
auto traits = fmha_bwd_
v3
_traits{{a.batch,
a.nhead_q,
a.nhead_q,
a.seqlen_q,
a.seqlen_q,
a.hdim_q,
a.hdim_q,
...
@@ -477,7 +477,7 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
...
@@ -477,7 +477,7 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
a.mask_type,
a.mask_type,
32,
32,
128}};
128}};
fmha_bwd_
ext
_kernel impl(HSA_KERNEL, bwd_
ext_asm);
static
fmha_bwd_
v3
_kernel impl(HSA_KERNEL, bwd_
v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}
...
@@ -485,11 +485,11 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
...
@@ -485,11 +485,11 @@ float fmha_ext_bwd_xqa_(const ck_tile::stream_config& s, fmha_bwd_args a, unsig
}}
}}
template <typename dot_do_o_trait_, typename convert_dq_trait_>
template <typename dot_do_o_trait_, typename convert_dq_trait_>
float fmha_
ext_
bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
ext_asm
[], const std::string& bwd_
ext
_name, bool io_perm)
float fmha_bwd
_v3
_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned char bwd_
v3_buf
[], const std::string& bwd_
v3
_name, bool io_perm)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
ext
_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << bwd_
v3
_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
fmha_bwd_
asm
_args args;
fmha_bwd_
v3
_args args;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dq = a.dq_acc_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dk = a.dk_ptr;
args.ptr_dv = a.dv_ptr;
args.ptr_dv = a.dv_ptr;
...
@@ -517,7 +517,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -517,7 +517,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
args.Hs = stride_head;
args.Hs = stride_head;
args.BAs = stride_batch;
args.BAs = stride_batch;
args.Seqs = stride_seqlen;
args.Seqs = stride_seqlen;
auto traits = fmha_bwd_
ext
_traits{{a.batch,
auto traits = fmha_bwd_
v3
_traits{{a.batch,
a.nhead_q,
a.nhead_q,
a.seqlen_q,
a.seqlen_q,
a.hdim_q,
a.hdim_q,
...
@@ -525,7 +525,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -525,7 +525,7 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
a.mask_type,
a.mask_type,
32,
32,
128}};
128}};
fmha_bwd_
ext
_kernel impl(HSA_KERNEL, bwd_
ext_asm);
static
fmha_bwd_
v3
_kernel impl(HSA_KERNEL, bwd_
v3_buf); // static here is for thread safety.
return ck_tile::launch_kernel(s,
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
[=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }},
...
@@ -536,139 +536,139 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
...
@@ -536,139 +536,139 @@ float fmha_ext_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a, unsigned
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
float r = -1;
float r = -1;
if (t.uses_
ext_asm
== true){{
if (t.uses_
bwd_v3
== true){{
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
if ((t.is_group_mode == false) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 128 == 0) && (a.hdim_q == 128) && (a.hdim_v == 128) && (t.is_deterministic == false) &&
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
(a.stride_q == a.stride_o /*i_perm == o_perm*/)) {{
if(t.data_type.compare("fp16") == 0){{
if(t.data_type.compare("fp16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_
asm
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
if((t.is_
v3
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_
asm_no_coex
== true){{
if(t.is_
v3_spec
== true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_
nocoex
_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_
spec
_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_
nocoex
_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_
spec
_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
else if((t.is_
asm
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_
v3
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_a16";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_a16, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_
asm
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
if((t.is_
v3
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_
asm_no_coex
== true){{
if(t.is_
v3_spec
== true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_
nocoex
_causal_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_
spec
_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_
nocoex
_causal_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_
spec
_causal_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::fp16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_causal_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_fp16_causal_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
else if((t.is_
asm
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_
v3
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::fp16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_fp16_causal_a16";
const std::string bwd_
v3
_name = "bwd_
v3
_fp16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_fp16_causal_a16, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
}}
}}
else if(t.data_type.compare("bf16") == 0){{
else if(t.data_type.compare("bf16") == 0){{
if(t.mask_type == mask_enum::no_mask){{
if(t.mask_type == mask_enum::no_mask){{
if((t.is_
asm
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
if((t.is_
v3
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_
asm_no_coex
== true){{
if(t.is_
v3_spec
== true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_
nocoex
_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_
spec
_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_
nocoex
_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_
spec
_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
else if((t.is_
asm
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_
v3
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.is_
asm
_rtz_cvt == true){{
if(t.is_
v3
_rtz_cvt == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_a16_rtz";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16_rtz, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_a16";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_a16, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
}}
}}
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
if((t.is_
asm
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
if((t.is_
v3
_atomic_fp32 == true) && (a.nhead_q == a.nhead_k /*MQA/GQA not supported yet*/)
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
&& (a.nhead_stride_dq_acc > a.stride_dq_acc /*dq_acc only support BHSD*/)){{
if(t.is_
asm_no_coex
== true){{
if(t.is_
v3_spec
== true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_
nocoex
_causal_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_
spec
_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_
nocoex
_causal_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_
spec
_causal_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, ck_tile::bf16_t, false, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_causal_a32";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_causal_a32";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_<dot_do_o_trait_, convert_dq_trait_>(s, a, bwd_bf16_causal_a32, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
else if((t.is_
asm
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
else if((t.is_
v3
_atomic_fp32 == false) && (a.nhead_q % a.nhead_k == 0)){{
if(t.is_
asm
_rtz_cvt == true){{
if(t.is_
v3
_rtz_cvt == true){{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_causal_a16_rtz";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_causal_a16_rtz";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16_rtz, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
else{{
else{{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, ck_tile::bf16_t, false, false, false>;
const std::string bwd_
ext
_name = "bwd_
ext
_bf16_causal_a16";
const std::string bwd_
v3
_name = "bwd_
v3
_bf16_causal_a16";
bool io_perm = a.nhead_stride_q > a.stride_q;
bool io_perm = a.nhead_stride_q > a.stride_q;
r = fmha_
ext_
bwd_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_
ext
_name, io_perm);
r = fmha_bwd
_v3
_xqa_<dot_do_o_trait_>(s, a, bwd_bf16_causal_a16, bwd_
v3
_name, io_perm);
return r;
return r;
}}
}}
}}
}}
...
...
example/ck_tile/01_fmha/fmha_bwd.cpp
View file @
d4de8495
...
@@ -92,17 +92,17 @@ auto create_args(int argc, char* argv[])
...
@@ -92,17 +92,17 @@ auto create_args(int argc, char* argv[])
"0"
,
"0"
,
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used"
)
"will not be used"
)
.
insert
(
"
ext_asm
"
,
"0"
,
"if set to 1, some cases will call the
ext asm
dqdkdv kernel"
)
.
insert
(
"
bwd_v3
"
,
"0"
,
"if set to 1, some cases will call the
bwd v3
dqdkdv kernel"
)
.
insert
(
.
insert
(
"
asm
_atomic_fp32"
,
"
v3
_atomic_fp32"
,
"1"
,
"1"
,
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when
ext_asm
is set to 1"
)
"if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) when
bwd_v3
is set to 1"
)
.
insert
(
"
asm_no_coex
"
,
.
insert
(
"
v3_spec
"
,
"0"
,
"0"
,
"if set to 1 will
use non-coexectuion
kernel when
ext_asm
is set to 1"
)
"if set to 1 will
call the specialized v3
kernel when
bwd_v3
is set to 1"
)
.
insert
(
"
asm
_rtz_cvt"
,
.
insert
(
"
v3
_rtz_cvt"
,
"0"
,
"0"
,
"if set to 1 will use float to bf16 RTZ convert when
ext_asm
is set to 1"
);
"if set to 1 will use float to bf16 RTZ convert when
bwd_v3
is set to 1"
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
@@ -191,10 +191,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -191,10 +191,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
int
stream_repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
stream_repeat
=
arg_parser
.
get_int
(
"repeat"
);
bool
kname
=
arg_parser
.
get_bool
(
"kname"
);
bool
kname
=
arg_parser
.
get_bool
(
"kname"
);
bool
deterministic
=
arg_parser
.
get_bool
(
"deterministic"
);
bool
deterministic
=
arg_parser
.
get_bool
(
"deterministic"
);
bool
ext_asm
=
arg_parser
.
get_bool
(
"
ext_asm
"
);
bool
bwd_v3
=
arg_parser
.
get_bool
(
"
bwd_v3
"
);
bool
asm
_atomic_fp32
=
arg_parser
.
get_bool
(
"
asm
_atomic_fp32"
);
bool
v3
_atomic_fp32
=
arg_parser
.
get_bool
(
"
v3
_atomic_fp32"
);
bool
asm_no_coex
=
arg_parser
.
get_bool
(
"
asm_no_coex
"
);
bool
v3_spec
=
arg_parser
.
get_bool
(
"
v3_spec
"
);
bool
asm
_rtz_cvt
=
arg_parser
.
get_bool
(
"
asm
_rtz_cvt"
);
bool
v3
_rtz_cvt
=
arg_parser
.
get_bool
(
"
v3
_rtz_cvt"
);
ck_tile
::
stream_config
stream_config
{
nullptr
,
ck_tile
::
stream_config
stream_config
{
nullptr
,
true
,
true
,
...
@@ -430,10 +430,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -430,10 +430,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
p_drop
>
0.0
f
,
p_drop
>
0.0
f
,
s_randval
,
s_randval
,
deterministic
,
deterministic
,
ext_asm
,
bwd_v3
,
asm
_atomic_fp32
,
v3
_atomic_fp32
,
asm_no_coex
,
v3_spec
,
asm
_rtz_cvt
};
v3
_rtz_cvt
};
auto
fmha_args
=
[
&
]()
{
auto
fmha_args
=
[
&
]()
{
assert
(
nhead
%
nhead_k
==
0
);
assert
(
nhead
%
nhead_k
==
0
);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
...
...
example/ck_tile/01_fmha/fmha_bwd.hpp
View file @
d4de8495
...
@@ -438,10 +438,10 @@ struct fmha_bwd_traits
...
@@ -438,10 +438,10 @@ struct fmha_bwd_traits
bool
has_dropout
;
bool
has_dropout
;
bool
is_store_randval
;
bool
is_store_randval
;
bool
is_deterministic
;
bool
is_deterministic
;
bool
uses_
ext_asm
;
bool
uses_
bwd_v3
;
bool
is_
asm
_atomic_fp32
;
bool
is_
v3
_atomic_fp32
;
bool
is_
asm_no_coex
;
bool
is_
v3_spec
;
bool
is_
asm
_rtz_cvt
;
bool
is_
v3
_rtz_cvt
;
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_bwd
(
fmha_bwd_traits
,
fmha_bwd_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/01_fmha/hsaco/bwd_bf16_
nocoex
_a32.cpp
→
example/ck_tile/01_fmha/hsaco/bwd_bf16_
spec
_a32.cpp
View file @
d4de8495
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_hsaco.hpp"
#include "fmha_hsaco.hpp"
unsigned
char
bwd_bf16_
nocoex
_a32
[]
=
{
unsigned
char
bwd_bf16_
spec
_a32
[]
=
{
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0xB0
,
0x7D
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0xB0
,
0x7D
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
...
...
example/ck_tile/01_fmha/hsaco/bwd_bf16_
nocoex
_causal_a32.cpp
→
example/ck_tile/01_fmha/hsaco/bwd_bf16_
spec
_causal_a32.cpp
View file @
d4de8495
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_hsaco.hpp"
#include "fmha_hsaco.hpp"
unsigned
char
bwd_bf16_
nocoex
_causal_a32
[]
=
{
unsigned
char
bwd_bf16_
spec
_causal_a32
[]
=
{
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x08
,
0x85
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x08
,
0x85
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
...
...
example/ck_tile/01_fmha/hsaco/bwd_fp16_
nocoex
_a32.cpp
→
example/ck_tile/01_fmha/hsaco/bwd_fp16_
spec
_a32.cpp
View file @
d4de8495
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_hsaco.hpp"
#include "fmha_hsaco.hpp"
unsigned
char
bwd_fp16_
nocoex
_a32
[]
=
{
unsigned
char
bwd_fp16_
spec
_a32
[]
=
{
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x98
,
0x5B
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x98
,
0x5B
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
...
...
example/ck_tile/01_fmha/hsaco/bwd_fp16_
nocoex
_causal_a32.cpp
→
example/ck_tile/01_fmha/hsaco/bwd_fp16_
spec
_causal_a32.cpp
View file @
d4de8495
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_hsaco.hpp"
#include "fmha_hsaco.hpp"
unsigned
char
bwd_fp16_
nocoex
_causal_a32
[]
=
{
unsigned
char
bwd_fp16_
spec
_causal_a32
[]
=
{
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x7F
,
0x45
,
0x4C
,
0x46
,
0x02
,
0x01
,
0x01
,
0x40
,
0x03
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x03
,
0x00
,
0xE0
,
0x00
,
0x01
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0xF0
,
0x62
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x40
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0xF0
,
0x62
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
0x00
,
...
...
example/ck_tile/01_fmha/hsaco/fmha_hsaco.hpp
View file @
d4de8495
...
@@ -9,11 +9,11 @@ extern unsigned char bwd_bf16_a32[];
...
@@ -9,11 +9,11 @@ extern unsigned char bwd_bf16_a32[];
extern
unsigned
char
bwd_bf16_causal_a16
[];
extern
unsigned
char
bwd_bf16_causal_a16
[];
extern
unsigned
char
bwd_bf16_causal_a16_rtz
[];
extern
unsigned
char
bwd_bf16_causal_a16_rtz
[];
extern
unsigned
char
bwd_bf16_causal_a32
[];
extern
unsigned
char
bwd_bf16_causal_a32
[];
extern
unsigned
char
bwd_bf16_
nocoex
_a32
[];
extern
unsigned
char
bwd_bf16_
spec
_a32
[];
extern
unsigned
char
bwd_bf16_
nocoex
_causal_a32
[];
extern
unsigned
char
bwd_bf16_
spec
_causal_a32
[];
extern
unsigned
char
bwd_fp16_a16
[];
extern
unsigned
char
bwd_fp16_a16
[];
extern
unsigned
char
bwd_fp16_a32
[];
extern
unsigned
char
bwd_fp16_a32
[];
extern
unsigned
char
bwd_fp16_causal_a16
[];
extern
unsigned
char
bwd_fp16_causal_a16
[];
extern
unsigned
char
bwd_fp16_causal_a32
[];
extern
unsigned
char
bwd_fp16_causal_a32
[];
extern
unsigned
char
bwd_fp16_
nocoex
_a32
[];
extern
unsigned
char
bwd_fp16_
spec
_a32
[];
extern
unsigned
char
bwd_fp16_
nocoex
_causal_a32
[];
extern
unsigned
char
bwd_fp16_
spec
_causal_a32
[];
example/ck_tile/01_fmha/script/benchmark_bwd_ext.sh
View file @
d4de8495
...
@@ -9,23 +9,23 @@ for hdim in 128 ; do
...
@@ -9,23 +9,23 @@ for hdim in 128 ; do
nhead
=
$((
2048
/
$hdim
))
# follow fav2 setup
nhead
=
$((
2048
/
$hdim
))
# follow fav2 setup
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-v
=
$VALID
;
sleep
3
done
done
done
done
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd_ext.sh
View file @
d4de8495
...
@@ -11,12 +11,12 @@ set -x
...
@@ -11,12 +11,12 @@ set -x
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
perm
in
0 1
;
do
for
hdim
in
128
;
do
for
hdim
in
128
;
do
for
asm
_atomic_fp32
in
0 1
;
do
for
v3
_atomic_fp32
in
0 1
;
do
for
asm_no_coex
in
0 1
;
do
for
v3_spec
in
0 1
;
do
for
mask
in
0 1
;
do
for
mask
in
0 1
;
do
$EXE
-prec
=
$prec
-b
=
4
-h
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
ext_asm
=
1
-
asm
_atomic_fp32
=
$
asm
_atomic_fp32
-
asm_no_coex
=
$asm_no_coex
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
4
-h
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
$
v3
_atomic_fp32
-
v3_spec
=
$v3_spec
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
ext_asm
=
1
-
asm
_atomic_fp32
=
$
asm
_atomic_fp32
-
asm_no_coex
=
$asm_no_coex
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
$
v3
_atomic_fp32
-
v3_spec
=
$v3_spec
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd_xqa_ext.sh
View file @
d4de8495
...
@@ -13,8 +13,8 @@ for perm in 0 1 ; do
...
@@ -13,8 +13,8 @@ for perm in 0 1 ; do
for
hdim
in
128
;
do
for
hdim
in
128
;
do
for
mask
in
0 1
;
do
for
mask
in
0 1
;
do
$EXE
-prec
=
$prec
-b
=
2
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-
asm
_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
2
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-
v3
_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
ext_asm
=
1
-
asm
_atomic_fp32
=
0
-
asm
_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
768
-iperm
=
$perm
-operm
=
$perm
-mask
=
$mask
-
bwd_v3
=
1
-
v3
_atomic_fp32
=
0
-
v3
_rtz_cvt
=
1
-mode
=
0
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment