Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
328ab6f3
Commit
328ab6f3
authored
Apr 30, 2022
by
Jianfeng yan
Browse files
removed A/B/CGridDesc from DeviceOps that use gridwise_gemm_v2r3 and gridwise_gemm_cshuffle
parent
e739c577
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
276 additions
and
269 deletions
+276
-269
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+2
-5
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+7
-6
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+7
-7
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...on/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+0
-3
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
+9
-11
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
+7
-8
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+13
-9
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
.../tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
+28
-28
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+5
-3
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+6
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+23
-11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+0
-2
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+8
-8
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+160
-160
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
328ab6f3
...
...
@@ -222,9 +222,6 @@ struct DeviceBatchedGemmXdl
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -373,7 +370,7 @@ struct DeviceBatchedGemmXdl
CDataType
,
remove_reference_t
<
DeviceBatchedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceBatchedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -407,7 +404,7 @@ struct DeviceBatchedGemmXdl
CDataType
,
remove_reference_t
<
DeviceBatchedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceBatchedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
328ab6f3
...
...
@@ -403,6 +403,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -492,7 +495,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01
,
N01
));
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]
.
GetLength
(
I0
),
descs
[
I2
].
GetLength
(
I1
)
,
M01
,
N01
));
}
}
}
...
...
@@ -504,7 +507,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
std
::
vector
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
;
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
index_t
M01_
;
...
...
@@ -594,8 +597,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
...
...
@@ -627,8 +629,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
328ab6f3
...
...
@@ -317,9 +317,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -351,6 +348,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -416,7 +416,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
)
,
M01
,
N01
);
}
}
...
...
@@ -427,7 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
...
...
@@ -492,7 +492,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -523,7 +523,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
328ab6f3
...
...
@@ -309,9 +309,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
AccDataType
,
OutDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
328ab6f3
...
...
@@ -960,9 +960,6 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -994,6 +991,9 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -1079,7 +1079,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
));
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]
.
GetLength
(
I0
),
descs
[
I2
].
GetLength
(
I1
)
,
M01_
,
N01_
));
}
}
}
...
...
@@ -1135,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
));
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]
.
GetLength
(
I0
),
descs
[
I2
].
GetLength
(
I1
)
,
M01_
,
N01_
));
}
}
}
...
...
@@ -1201,7 +1201,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
));
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
]
.
GetLength
(
I0
),
descs
[
I2
].
GetLength
(
I1
)
,
M01_
,
N01_
));
}
}
}
...
...
@@ -1214,7 +1214,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
std
::
vector
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
;
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
index_t
M01_
;
...
...
@@ -1308,8 +1308,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
...
...
@@ -1341,8 +1340,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
328ab6f3
...
...
@@ -614,9 +614,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -648,6 +645,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -713,7 +713,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
)
,
M01
,
N01
);
}
}
...
...
@@ -724,8 +724,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -789,7 +788,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -820,7 +819,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
328ab6f3
...
...
@@ -187,9 +187,9 @@ struct DeviceGemmXdl
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
//
AGridDesc_K0_M_K1,
//
BGridDesc_K0_N_K1,
//
CGridDesc_M_N,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -222,6 +222,11 @@ struct DeviceGemmXdl
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
1
,
1
,
1
,
1
));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -264,7 +269,7 @@ struct DeviceGemmXdl
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
)
,
M01
,
N01
);
}
}
...
...
@@ -275,9 +280,8 @@ struct DeviceGemmXdl
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
...
...
@@ -331,7 +335,7 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -362,7 +366,7 @@ struct DeviceGemmXdl
CDataType
,
remove_reference_t
<
DeviceGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
View file @
328ab6f3
...
...
@@ -342,9 +342,6 @@ struct DeviceGemm_Xdl_CShuffle
BElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -377,6 +374,9 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -422,7 +422,7 @@ struct DeviceGemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
...
...
@@ -470,8 +470,8 @@ struct DeviceGemm_Xdl_CShuffle
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AElementwiseOperation
,
...
...
@@ -479,7 +479,7 @@ struct DeviceGemm_Xdl_CShuffle
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
true
>
;
...
...
@@ -522,8 +522,8 @@ struct DeviceGemm_Xdl_CShuffle
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AElementwiseOperation
,
...
...
@@ -531,7 +531,7 @@ struct DeviceGemm_Xdl_CShuffle
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
false
>
;
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
328ab6f3
...
...
@@ -419,6 +419,8 @@ struct DeviceGemmXdlSplitKCShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
BatchedGemmUtil
::
MakeBlock2CTileMap
<
MPerBlock
,
NPerBlock
>
(
1
,
1
,
1
));
struct
Argument
:
public
BaseArgument
...
...
@@ -505,7 +507,7 @@ struct DeviceGemmXdlSplitKCShuffle
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
...
...
@@ -564,7 +566,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
true
>
;
...
...
@@ -621,7 +623,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
false
>
;
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
328ab6f3
...
...
@@ -188,9 +188,6 @@ struct DeviceGroupedGemmXdl
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -223,11 +220,14 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector
,
NumPrefetch
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
struct
GroupedGemmBlock2CTileMap
{
GroupedGemmBlock2CTileMap
()
{
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}
,
1
,
1
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
1
,
1
,
1
,
1
);
BlockStart_
=
-
1
;
}
...
...
@@ -236,7 +236,7 @@ struct DeviceGroupedGemmXdl
index_t
N01
,
ck
::
index_t
BlockStart
)
{
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
)
,
M01
,
N01
);
BlockStart_
=
BlockStart
;
}
...
...
@@ -258,8 +258,7 @@ struct DeviceGroupedGemmXdl
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
GroupedGemmBlock2CTileMap
grouped_gemm_block_2_ctile_map_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
328ab6f3
...
...
@@ -150,9 +150,6 @@ template <typename FloatAB,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -261,6 +258,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
...
...
@@ -307,9 +305,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
true
;
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
static_assert
(
CGridDesc_M_N
::
GetNumOfVisibleDimension
()
==
2
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
...
...
@@ -326,9 +327,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
has_main_k0_block_loop
;
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
static_assert
(
CGridDesc_M_N
::
GetNumOfVisibleDimension
()
==
2
);
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
...
...
@@ -347,11 +351,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeDefaultBlock2CTileMap
(
index_t
M
,
index_t
N
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
...
...
@@ -385,13 +386,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
static_assert
(
CGridDesc_M_N
::
GetNumOfVisibleDimension
()
==
2
);
return
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
}
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}
))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
1
,
1
))
>
;
template
<
bool
HasMainK0BlockLoop
,
typename
Block2CTileMap
>
template
<
bool
HasMainK0BlockLoop
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
328ab6f3
...
...
@@ -525,8 +525,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
cblockid_to_m0_n0_block_cluster_adaptor
;
}
// using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
// decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using
DefaultBlock2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
1
,
1
,
1
,
1
));
template
<
bool
HasMainK0BlockLoop
,
...
...
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
328ab6f3
# device_gemm_instance
set
(
DEVICE_GEMM_INSTANCE_SOURCE
#
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
#
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
#
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
#
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
#
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
#
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
#
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
#
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
...
...
profiler/include/profile_gemm_impl.hpp
View file @
328ab6f3
...
...
@@ -23,10 +23,10 @@ using DeviceGemmNoOpPtr =
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
//
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
// void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
// std::vector<DeviceGemmNoOpPtr>&);
...
...
@@ -54,10 +54,10 @@ using DeviceGemmNoOpPtr =
// void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
// std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
//
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
...
...
@@ -171,8 +171,8 @@ void profile_gemm_impl(int do_verification,
}
else
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
gemm_ptrs
);
//
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
...
...
@@ -189,8 +189,8 @@ void profile_gemm_impl(int do_verification,
}
else
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
gemm_ptrs
);
//
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
...
...
@@ -207,8 +207,8 @@ void profile_gemm_impl(int do_verification,
}
else
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
//
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
...
...
@@ -225,97 +225,97 @@ void profile_gemm_impl(int do_verification,
}
else
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
gemm_ptrs
);
//
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
}
//
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
//
is_same<CDataType, half_t>::value)
//
{
//
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
//
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
//
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
//
{
//
if(KBatch > 1)
//
{
else
if
constexpr
(
is_same
<
ADataType
,
half_t
>::
value
&&
is_same
<
BDataType
,
half_t
>::
value
&&
is_same
<
CDataType
,
half_t
>::
value
)
{
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
if
(
KBatch
>
1
)
{
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
//
}
//
else
//
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
//
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
gemm_ptrs
);
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
//
}
//
}
//
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
//
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
//
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
//
{
//
if(KBatch > 1)
//
{
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
if
(
KBatch
>
1
)
{
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
//
}
//
else
//
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
//
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
gemm_ptrs
);
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
//
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
//
}
//
}
//
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
//
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
//
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
//
{
//
if(KBatch > 1)
//
{
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
if
(
KBatch
>
1
)
{
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
//
}
//
else
//
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
//
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
gemm_ptrs
);
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
//
}
//
}
//
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
//
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
//
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
//
{
//
if(KBatch > 1)
//
{
}
}
else
if
constexpr
(
is_same
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>::
value
&&
is_same
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
if
(
KBatch
>
1
)
{
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
//
}
//
else
//
{
//
ck::tensor_operation::device::device_gemm_instance::
//
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
//
}
else
{
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
gemm_ptrs
);
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
//
}
//
}
//
}
//
else if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
//
is_same<BDataType, ck::bhalf_t>::value &&
//
is_same<CDataType, ck::bhalf_t>::value)
//
{
}
}
}
else
if
constexpr
(
is_same
<
ADataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
BDataType
,
ck
::
bhalf_t
>::
value
&&
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
// if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
// is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
// is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
...
...
@@ -344,10 +344,10 @@ void profile_gemm_impl(int do_verification,
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs);
// }
//
}
//
else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
//
is_same<CDataType, int8_t>::value)
//
{
}
else
if
constexpr
(
is_same
<
ADataType
,
int8_t
>::
value
&&
is_same
<
BDataType
,
int8_t
>::
value
&&
is_same
<
CDataType
,
int8_t
>::
value
)
{
// if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
// is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
// is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
...
...
@@ -376,7 +376,7 @@ void profile_gemm_impl(int do_verification,
// ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs);
// }
//
}
}
if
(
gemm_ptrs
.
size
()
<=
0
)
{
...
...
test/CMakeLists.txt
View file @
328ab6f3
...
...
@@ -44,4 +44,4 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory
(
grouped_gemm
)
add_subdirectory
(
convnd_fwd
)
add_subdirectory
(
reduce
)
add_subdirectory
(
conv2d_bwd_weight
)
#
add_subdirectory(conv2d_bwd_weight)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment