Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2a12493d
"...composable_kernel_rocm.git" did not exist on "9a7fa123fdade18ffa125b8d1647d24dda6f889d"
Commit
2a12493d
authored
Dec 17, 2024
by
Muhammed Emin Ozturk
Browse files
gemm_universal_streamk.hpp
parent
f342446c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
72 deletions
+72
-72
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp
.../tensor_operation_instance/gpu/gemm_universal_streamk.hpp
+72
-72
No files found.
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_streamk.hpp
View file @
2a12493d
...
@@ -240,181 +240,181 @@ namespace instance {
...
@@ -240,181 +240,181 @@ namespace instance {
// Emin @Added
// Emin @Added
#ifdef CK_ENABLE_BF16
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Row
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
void
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmV2
<
Col
,
Col
,
Row
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
...
@@ -719,97 +719,97 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
...
@@ -719,97 +719,97 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_comp_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
CLayout
,
Row
>
)
is_same_v
<
CLayout
,
Row
>
)
{
{
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances
(
op_ptrs
);
op_ptrs
);
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
add_device_gemm_xdl_universal_
streamk_
bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances
(
op_ptrs
);
op_ptrs
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment