Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a6ef5c39
Commit
a6ef5c39
authored
May 17, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
9b3c4ac4
1274861a
Changes
39
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
234 additions
and
52 deletions
+234
-52
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+10
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+10
-10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
...tion/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
+9
-9
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+10
-10
include/ck/utility/env.hpp
include/ck/utility/env.hpp
+1
-1
include/ck_tile/core/numeric/half.hpp
include/ck_tile/core/numeric/half.hpp
+2
-2
include/ck_tile/core/numeric/integral_constant.hpp
include/ck_tile/core/numeric/integral_constant.hpp
+0
-1
profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
...r/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_impl.hpp
profiler/include/profiler/profile_grouped_gemm_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
.../include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
.../include/profiler/profile_grouped_gemm_two_stage_impl.hpp
+1
-1
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+6
-0
test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
...emm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
+62
-0
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
+61
-0
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+54
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp
View file @
a6ef5c39
...
@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
...
@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
a6ef5c39
...
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_d_grid_descs_m_n_
.
reserve
(
group_count_
);
elementwise_d_grid_descs_m_n_
.
reserve
(
group_count_
);
ds_grid_pointer_
.
reserve
(
group_count_
);
ds_grid_pointer_
.
reserve
(
group_count_
);
group_grid_size_
.
reserve
(
group_count_
);
group_grid_size_
.
reserve
(
group_count_
);
e_ptrs_
.
reserve
(
group_count_
);
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
++
i
)
{
{
...
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
const
index_t
block_end
=
grid_size_
+
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
grid_size_
+=
grid_size_grp
;
group_grid_size_
[
i
]
=
grid_size_grp
;
group_grid_size_
.
push_back
(
grid_size_grp
)
;
// block-to-e-tile map
// block-to-e-tile map
auto
grouped_block_2_ctile_map
=
auto
grouped_block_2_ctile_map
=
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
...
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_c_grid_descs_m_n_
.
push_back
(
c_grid_desc_m_n
);
elementwise_c_grid_descs_m_n_
.
push_back
(
c_grid_desc_m_n
);
elementwise_d_grid_descs_m_n_
.
push_back
(
ds_grid_desc_m_n
);
elementwise_d_grid_descs_m_n_
.
push_back
(
ds_grid_desc_m_n
);
ds_grid_pointer_
.
push_back
(
p_ds_grid
);
ds_grid_pointer_
.
push_back
(
p_ds_grid
);
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_
.
push_back
(
p_Es
[
i
]);
}
}
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_
=
p_Es
;
}
}
/**
/**
...
@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
gemm_kernel_args_
[
i
].
block_end_
=
block_end
;
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
index_t
tiles
=
(
block_end
-
block_start
)
/
K_BATCH
;
index_t
tiles
=
(
block_end
-
block_start
)
/
K_BATCH
;
std
::
cout
<<
"block_start: "
<<
block_start
<<
"
\n
"
std
::
cout
<<
"block_start: "
<<
block_start
<<
"
\n
"
...
@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg
.
karg_
.
p_c_grid
=
p_workspace
+
offset
;
arg
.
karg_
.
p_c_grid
=
p_workspace
+
offset
;
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
index_t
tiles
=
(
arg
.
block_end_
-
arg
.
block_start_
)
/
arg
.
karg_
.
k_batch
;
offset
+=
tiles
*
MPerBlock
*
NPerBlock
;
offset
+=
tiles
*
MPerBlock
*
NPerBlock
;
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"block_start: "
<<
arg
.
block_start_
<<
"
\n
"
std
::
cout
<<
"block_start: "
<<
arg
.
block_start_
<<
"
\n
"
<<
"block_end: "
<<
arg
.
block_end_
<<
"
\n
"
<<
"block_end: "
<<
arg
.
block_end_
<<
"
\n
"
...
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
dev_gemm_args
),
cast_pointer_to_constant_address_space
(
dev_gemm_args
),
arg
.
g
roup_count_
,
arg
.
g
emm_kernel_args_
.
size
()
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
PassThrough
{});
PassThrough
{});
// Elementwise kernels
// Elementwise kernels
for
(
in
t
i
=
0
;
i
<
arg
.
g
roup_count_
;
++
i
)
for
(
size_
t
i
=
0
;
i
<
arg
.
g
emm_kernel_args_
.
size
()
;
++
i
)
{
{
time
+=
launch_and_time_kernel
(
time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
...
@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
"and kernel args size!"
"and kernel args size!"
...
@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
gemm_arg
);
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
gemm_arg
);
if
(
not
group_arg_valid
)
if
(
not
group_arg_valid
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
a6ef5c39
...
@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
GridwiseGemm
::
template
CheckTensorTransfersValidity
<
ALayout
,
BLayout
,
ELayout
>(
GridwiseGemm
::
template
CheckTensorTransfersValidity
<
ALayout
,
BLayout
,
ELayout
>(
M
,
N
,
K
)))
M
,
N
,
K
)))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"The provided GEMM problem size (M,N,K) ["
<<
M
<<
","
<<
N
<<
","
std
::
cout
<<
"The provided GEMM problem size (M,N,K) ["
<<
M
<<
","
<<
N
<<
","
<<
K
<<
"] are not supported by current template parameters!"
<<
K
<<
"] are not supported by current template parameters!"
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
a6ef5c39
...
@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_ak0_m_ak1_{"
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
arg
.
gemm_desc_kernel_arg_
[
i
].
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
a6ef5c39
...
@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
if
((
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_kernel_args_
.
size
())
+
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
arg
.
skipped_group_count_
)
!=
arg
.
group_count_
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
std
::
cout
<<
"The group count is not equal to sum of skipped groups "
"and kernel args size!"
"and kernel args size!"
...
@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
if
(
not
group_arg_valid
)
if
(
not
group_arg_valid
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
std
::
cout
<<
"["
<<
__func__
<<
"] group id: "
<<
i
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
<<
" has invalid GridwiseGemm settings!"
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
a6ef5c39
...
@@ -935,7 +935,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -935,7 +935,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -952,7 +952,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -952,7 +952,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -971,7 +971,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -971,7 +971,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
...
@@ -995,7 +995,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -995,7 +995,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -1009,7 +1009,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1009,7 +1009,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -1024,7 +1024,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1024,7 +1024,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -1038,7 +1038,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1038,7 +1038,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -1053,7 +1053,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1053,7 +1053,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
@@ -1069,7 +1069,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1069,7 +1069,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
@@ -1084,7 +1084,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1084,7 +1084,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
if
constexpr
(
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
if
constexpr
(
is_same
<
remove_cvref_t
<
CDataType
>
,
bhalf_t
>::
value
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
" KBatch: "
<<
karg
.
KBatch
<<
" > 1 is not support yet"
<<
__FILE__
std
::
cout
<<
" KBatch: "
<<
karg
.
KBatch
<<
" > 1 is not support yet"
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
View file @
a6ef5c39
...
@@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
auto
K_t
=
karg
.
KBatch
*
KPerBlock
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
...
@@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
@@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
a6ef5c39
...
@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
std
::
cout
<<
"Arg M value is not a multiple of MPerBlock! M: "
<<
karg
.
M
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
std
::
cout
<<
"Arg N value is not a multiple of NPerBlock! N: "
<<
karg
.
N
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
...
@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto
K_t
=
karg
.
k_batch
*
K0PerBlock
*
K1
;
auto
K_t
=
karg
.
k_batch
*
K0PerBlock
*
K1
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
std
::
cout
<<
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
<<
karg
.
K
<<
" "
<<
__FILE__
<<
":"
<<
__LINE__
...
@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of ABlockTransferSrcScalarPerVector ("
...
@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg K ("
<<
karg
.
K
std
::
cout
<<
"Arg K ("
<<
karg
.
K
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<<
") value is not a multiple of BBlockTransferSrcScalarPerVector ("
...
@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg N ("
<<
karg
.
N
std
::
cout
<<
"Arg N ("
<<
karg
.
N
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
{
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"Arg M ("
<<
karg
.
M
std
::
cout
<<
"Arg M ("
<<
karg
.
M
<<
") value is not a multiple of "
<<
") value is not a multiple of "
...
@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
...
@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
num_k_loop
=
karg
.
K0Padded
/
K0PerBlock
;
const
auto
num_k_loop
=
karg
.
K0Padded
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
{
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"The number of k loops ("
<<
num_k_loop
std
::
cout
<<
"The number of k loops ("
<<
num_k_loop
<<
") value is not supported by GridwiseGemm Pipeline."
<<
") value is not supported by GridwiseGemm Pipeline."
...
...
include/ck/utility/env.hpp
View file @
a6ef5c39
...
@@ -124,7 +124,7 @@ struct EnvVar
...
@@ -124,7 +124,7 @@ struct EnvVar
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
#define ENV(name) \
#define
CK_
ENV(name) \
ck::env::name {}
ck::env::name {}
template
<
class
EnvVar
>
template
<
class
EnvVar
>
...
...
include/ck_tile/core/numeric/half.hpp
View file @
a6ef5c39
...
@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x)
...
@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x)
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
constexpr
fp16_hip_t
float_to_fp16_hip
(
const
float
&
x
)
constexpr
fp16_hip_t
float_to_fp16_hip
(
const
float
&
x
)
{
{
return
__float2half
(
x
);
//
return __float2half(x);
//
return static_cast<fp16_hip_t>(x);
return
static_cast
<
fp16_hip_t
>
(
x
);
}
}
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
...
...
include/ck_tile/core/numeric/integral_constant.hpp
View file @
a6ef5c39
...
@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
...
@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
CK_TILE_LEFT_UNARY_OP
(
-
)
CK_TILE_LEFT_UNARY_OP
(
-
)
CK_TILE_LEFT_UNARY_OP
(
~
)
CK_TILE_LEFT_UNARY_OP
(
~
)
CK_TILE_LEFT_UNARY_OP
(
!
)
CK_TILE_LEFT_UNARY_OP
(
!
)
CK_TILE_LEFT_UNARY_OP
(
*
)
CK_TILE_BINARY_OP
(
+
)
CK_TILE_BINARY_OP
(
+
)
CK_TILE_BINARY_OP
(
-
)
CK_TILE_BINARY_OP
(
-
)
...
...
profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
View file @
a6ef5c39
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
c_m_n_host_results
.
push_back
(
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_impl.hpp
View file @
a6ef5c39
...
@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
...
@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
c_m_n_host_results
.
push_back
(
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
View file @
a6ef5c39
...
@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
...
@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
c_m_n_host_results
.
push_back
(
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
View file @
a6ef5c39
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
c_m_n_host_results
.
push_back
(
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
test/grouped_gemm/CMakeLists.txt
View file @
a6ef5c39
...
@@ -6,6 +6,12 @@ if(result EQUAL 0)
...
@@ -6,6 +6,12 @@ if(result EQUAL 0)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_splitk
)
endif
()
endif
()
add_gtest_executable
(
test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_two_stage_splitk
)
endif
()
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
add_gtest_executable
(
test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp
)
if
(
result EQUAL 0
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
target_link_libraries
(
test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance
)
...
...
test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp
0 → 100644
View file @
a6ef5c39
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
RRR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
using
RRR_F16_F16_F16_LargeK
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16_LargeK
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
using
RRR_BF16_BF16_BF16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Row
,
Row
,
BF16
,
BF16
,
BF16
>>
;
using
RCR_BF16_BF16_BF16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Col
,
Row
,
BF16
,
BF16
,
BF16
>>
;
using
RRR_BF16_I8_BF16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Row
,
Row
,
BF16
,
I8
,
BF16
>>
;
using
RCR_BF16_I8_BF16
=
ck
::
test
::
TestGroupedGemmTwoStage
<
std
::
tuple
<
Row
,
Col
,
Row
,
BF16
,
I8
,
BF16
>>
;
const
std
::
vector
<
int
>
KBATCH
{
1
,
2
,
3
,
5
,
8
};
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_KN
,
RRR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_NK
,
RCR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_KN_BF16
,
RRR_BF16_BF16_BF16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_NK_BF16
,
RCR_BF16_BF16_BF16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8
,
RRR_BF16_I8_BF16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8
,
RCR_BF16_I8_BF16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_LargeK_MK_KN
,
RRR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemmTwoStage_splitk_LargeK_MK_NK
,
RCR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
#include "test_grouped_gemm_ut_cases.inc"
#include "test_grouped_gemm_two_stage_ut_cases.inc"
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
0 → 100644
View file @
a6ef5c39
#pragma once
TEST_P
(
RRR_BF16_BF16_BF16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
constexpr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_BF16_BF16_BF16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
constexpr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_BF16_I8_BF16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
constexpr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_BF16_I8_BF16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
constexpr
int
N
=
136
;
constexpr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
test/grouped_gemm/test_grouped_gemm_util.hpp
View file @
a6ef5c39
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace
ck
{
namespace
ck
{
namespace
test
{
namespace
test
{
...
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
...
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
}
}
};
};
template
<
typename
Tuple
>
class
TestGroupedGemmTwoStage
:
public
testing
::
TestWithParam
<
int
>
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
ELayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
static
constexpr
bool
verify_
=
true
;
static
constexpr
int
init_method_
=
1
;
// decimal value initialization
static
constexpr
bool
log_
=
false
;
static
constexpr
bool
bench_
=
false
;
// measure kernel performance
void
SetUp
()
override
{}
void
Run
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
const
std
::
vector
<
int
>&
StrideAs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
,
int
n_warmup
=
1
,
int
n_iter
=
10
)
{
bool
pass
=
ck
::
profiler
::
profile_grouped_gemm_two_stage_impl
<
ADataType
,
BDataType
,
EDataType
,
float
,
ALayout
,
BLayout
,
ELayout
>
(
verify_
,
init_method_
,
log_
,
bench_
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
,
n_warmup
,
n_iter
);
EXPECT_TRUE
(
pass
);
}
};
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
ELayout
,
typename
ELayout
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment