gaoqiong / composable_kernel_ROCM / Commits / f0fd0263

Commit f0fd0263, authored Jul 21, 2023 by Jun Liu
Merge branch 'amd-develop' into amd-master
Parents: 4e911f3e, a8fafc3f
Changes: 94 files in the commit
Showing 20 changed files with 2012 additions and 739 deletions (+2012 −739)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp (+88 −78)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (+382 −166)
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp (+719 −252)
include/ck/utility/type_convert.hpp (+0 −18)
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp (+14 −7)
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp (+24 −22)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp (+141 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp (+119 −0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp (+171 −33)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp (+178 −46)
library/src/tensor_operation_instance/gpu/CMakeLists.txt (+36 −3)
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt (+22 −18)
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt (+17 −10)
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp (+2 −1)
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt (+89 −80)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp (+2 −1)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp (+2 −1)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp (+2 −1)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp (+2 −1)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp (+2 −1)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp

@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
...

@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     } // function end

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
...

@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
-        ck::index_t batch_k)
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
+        const ck::index_t batch_k)
     {
         using namespace ck;
...

@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         Argument(const InDataType* p_in_grid,
             WeiDataType* p_wei_grid,
             const OutDataType* p_out_grid,
-            ck::index_t G,
-            ck::index_t N,
-            ck::index_t K,
-            ck::index_t C,
-            std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-            std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-            std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-            std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-            std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-            std::array<ck::index_t, NDimSpatial> input_left_pads,
-            std::array<ck::index_t, NDimSpatial> input_right_pads,
+            const ck::index_t G,
+            const ck::index_t N,
+            const ck::index_t K,
+            const ck::index_t C,
+            const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+            const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+            const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+            const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
+            const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
+            const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+            const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+            const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+            const std::array<ck::index_t, NDimSpatial>& input_right_pads,
             InElementwiseOperation in_element_op,
             WeiElementwiseOperation wei_element_op,
             OutElementwiseOperation out_element_op,
...

@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
         InElementwiseOperation c_element_op_;

         // for checking IsSupportedArgument()
-        index_t Conv_G_;
-        index_t Conv_N_;
-        index_t Conv_K_;
-        index_t Conv_C_;
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations_;
-        std::array<ck::index_t, NDimSpatial> input_left_pads_;
-        std::array<ck::index_t, NDimSpatial> input_right_pads_;
+        const index_t Conv_G_;
+        const index_t Conv_N_;
+        const index_t Conv_K_;
+        const index_t Conv_C_;
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads_;

         index_t k_batch_;
     };
...

@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     static auto MakeArgument(const InDataType* p_in_grid,
         WeiDataType* p_wei_grid,
         const OutDataType* p_out_grid,
-        ck::index_t G,
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
+        const ck::index_t G,
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
        InElementwiseOperation in_element_op,
        WeiElementwiseOperation wei_element_op,
        OutElementwiseOperation out_element_op,
...

@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...

@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
     MakeArgumentPointer(const void* p_in_grid,
         void* p_wei_grid,
         const void* p_out_grid,
-        ck::index_t G,
-        ck::index_t N,
-        ck::index_t K,
-        ck::index_t C,
-        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-        std::array<ck::index_t, NDimSpatial> input_left_pads,
-        std::array<ck::index_t, NDimSpatial> input_right_pads,
+        const ck::index_t G,
+        const ck::index_t N,
+        const ck::index_t K,
+        const ck::index_t C,
+        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
        InElementwiseOperation in_element_op,
        WeiElementwiseOperation wei_element_op,
        OutElementwiseOperation out_element_op,
...

@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                         input_spatial_lengths,
                         filter_spatial_lengths,
                         output_spatial_lengths,
+                        input_strides,
+                        output_strides,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...
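The hunks above change this DL backward-weight device op so that scalar sizes are taken as const values, the spatial arrays are taken by const reference, and MakeArgument/MakeArgumentPointer additionally accept two NDimSpatial + 3 element-stride arrays for the input and output tensors (the commented-out parameter names in the Argument constructor show that this DL implementation ignores them). A minimal caller sketch for the 2-D case, following only the new argument order shown above; the device-op object, the grid pointers, the element ops and the trailing split-K value are assumptions, not part of this diff:

    // Hedged sketch: everything except the argument order is assumed.
    constexpr ck::index_t NDimSpatial = 2;
    const ck::index_t G = 1, N = 128, K = 256, C = 192;
    const std::array<ck::index_t, NDimSpatial> in_lengths{71, 71}, wei_lengths{3, 3}, out_lengths{36, 36};
    const std::array<ck::index_t, NDimSpatial + 3> in_strides{};  // G/N/C/spatial strides (unused by this DL op)
    const std::array<ck::index_t, NDimSpatial + 3> out_strides{}; // G/N/K/spatial strides (unused by this DL op)
    const std::array<ck::index_t, NDimSpatial> filter_strides{2, 2}, dilations{1, 1}, left_pads{1, 1}, right_pads{1, 1};

    auto argument_ptr = conv_op.MakeArgumentPointer(p_in, p_wei, p_out,
                                                    G, N, K, C,
                                                    in_lengths, wei_lengths, out_lengths,
                                                    in_strides, out_strides,
                                                    filter_strides, dilations, left_pads, right_pads,
                                                    in_element_op, wei_element_op, out_element_op,
                                                    split_k);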
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (renamed; diff collapsed, not shown)
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp (diff collapsed, not shown)
include/ck/utility/type_convert.hpp

@@ -62,24 +62,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_
     return type_convert<bhalf_t>(x_fp32);
 }

-// convert bfp16 to int32 via fp32
-template <>
-inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
-{
-    float x_fp32 = type_convert<float>(x);
-
-    return static_cast<int32_t>(x_fp32);
-}
-
-// convert int32 to bfp16 via fp32
-template <>
-inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
-{
-    float x_fp32 = static_cast<float>(x);
-
-    return type_convert<bhalf_t>(x_fp32);
-}
-
 // convert bfp16 to int8 via fp32
 template <>
 inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
...
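The two deleted specializations converted bhalf_t to int32_t and back by round-tripping through fp32. Callers that still need that conversion can spell the round trip out themselves; a minimal sketch, using only the fp32 conversions that remain in this header (the free-function wrappers are hypothetical names for illustration):

    #include "ck/utility/type_convert.hpp"

    // Same effect as the removed int32_t <- bhalf_t specialization: go through fp32.
    inline int32_t bhalf_to_int32(ck::bhalf_t x)
    {
        const float x_fp32 = ck::type_convert<float>(x); // bf16 -> fp32 (still provided)
        return static_cast<int32_t>(x_fp32);             // fp32 -> int32
    }

    // Same effect as the removed bhalf_t <- int32_t specialization.
    inline ck::bhalf_t int32_to_bhalf(int32_t x)
    {
        return ck::type_convert<ck::bhalf_t>(static_cast<float>(x));
    }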
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp

@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32,
         PassThrough, PassThrough, PassThrough>>>& instances);

+#ifdef __int8__
 void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<1, NWC,
...

@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 // conv2d backward data
 void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<2,
...

@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);

+#ifdef __int8__
 void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC,
...

@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 // conv2d dl
 void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<2,
...

@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);

+#ifdef __int8__
 void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC,
...

@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 // conv3d backward data
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<3,
...

@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);

+#ifdef __int8__
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<3, NDHWC,
...

@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
         PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 template <ck::index_t NumDimSpatial,
           typename InLayout,
           typename WeiLayout,
...

@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
             {
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
             }
+#ifdef __int8__
             else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                               is_same_v<OutDataType, int8_t>)
            {
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
            }
+#endif
         }
         else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                           is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
...

@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
             {
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
             }
+#ifdef __int8__
             else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                               is_same_v<OutDataType, int8_t>)
            {
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
                 add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
            }
+#endif
         }
         else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                           is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
...

@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
             {
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
             }
+#ifdef __int8__
             else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                               is_same_v<OutDataType, int8_t>)
            {
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
            }
+#endif
         }

         return op_ptrs;
...
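With the #ifdef __int8__ guards above, the int8 instance lists are only declared and registered when the library is built with int8 support; the factory interface itself does not change. A hedged sketch of enumerating 2-D NHWC backward-data instances through this header (the layout, data-type, and element-op aliases are the ones this header already uses; their exact qualification is assumed):

    #include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp"

    // 2-D NHWC/KYXC/NHWK fp16 backward-data instances registered by this header.
    using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData<
        2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // For int8 data types the returned vector is empty unless the build defines __int8__.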
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp

@@ -77,7 +77,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F32, F32, F32,
         PassThrough, PassThrough, PassThrough>>>& instances);

+#ifdef __int8__
 void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t,
         PassThrough, PassThrough, PassThrough>>>&
...

@@ -118,6 +118,27 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
         DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t,
             PassThrough, PassThrough, PassThrough>>>& instances);

+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+#endif

 void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
     std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, F16, F16, F16,
         PassThrough, PassThrough, PassThrough>>>&
...

@@ -183,26 +204,6 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
         DeviceGemm<Row, Col, Row, F32, F32, F32,
             PassThrough, PassThrough, PassThrough>>>& instances);

-void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
-void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
-void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
-    std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
-void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t,
-        PassThrough, PassThrough, PassThrough>>>& instances);
-
 void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGemm<Col, Row, Row, F16, F16, F16,
         PassThrough, PassThrough, PassThrough>>>&
...

@@ -388,6 +389,7 @@ struct DeviceOperationInstanceFactory<
                 add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
             }
         }
+#ifdef __int8__
         else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
                           is_same_v<CDataType, int8_t>)
         {
...

@@ -420,7 +422,7 @@ struct DeviceOperationInstanceFactory<
                 add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
             }
         }
+#endif

         return op_ptrs;
     }
 };
...
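The int8 xdl_c_shuffle declarations above are not new; they were moved so that every int8 GEMM declaration sits inside the same #ifdef __int8__ block as the DL ones, and the int8 branch of the factory is compiled out together with them. Downstream registration code that calls these add_* functions directly needs the matching guard; a minimal sketch:

    // Hedged sketch: mirror the header's guard so int8 GEMM paths disappear from
    // builds configured without int8 support.
    #ifdef __int8__
        add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
        add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
    #endif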
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp (new file, 0 → 100644; diff collapsed, not shown)
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp (new file, 0 → 100644; diff collapsed, not shown)
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp

@@ -16,7 +16,7 @@ namespace device {
 namespace instance {

 // conv2d backward data
-void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC,
...

@@ -30,7 +30,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
         PassThrough, PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC,
...

@@ -44,7 +44,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
         PassThrough, PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC,
...

@@ -58,7 +58,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
         PassThrough, PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC,
...

@@ -72,7 +72,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
         PassThrough, PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC,
...

@@ -86,7 +86,7 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
         PassThrough, PassThrough>>>& instances);

-void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
+void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2, NHWGK, GKYXC,
...

@@ -100,6 +100,91 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
         PassThrough, PassThrough>>>& instances);

+// conv3d backward data
+void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC,
+        F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC,
+        F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, GNDHWK, GKZYXC, Empty_Tuple, GNDHWC,
+        BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
+        F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
+        F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
+        BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
+
 template <ck::index_t NumDimSpatial,
           typename OutLayout,
           typename WeiLayout,
...

@@ -139,43 +224,96 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
-                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
-        {
-            if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
-            { add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
-            { add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
-            { add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); }
-        }
-        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
-                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
-        {
-            if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
-            { add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
-            { add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
-            { add_device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); }
-        }
+        if constexpr(NumDimSpatial == 2)
+        {
+            if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
+            {
+                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
+                { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
+                { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
+                { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(op_ptrs); }
+            }
+            else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
+            {
+                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
+                { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
+                { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
+                { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(op_ptrs); }
+            }
+        }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
+            {
+                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
+                { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
+                { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
+                { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(op_ptrs); }
+            }
+            else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
+            {
+                if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> && is_same_v<OutDataType, F16>)
+                { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> && is_same_v<OutDataType, F32>)
+                { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> && is_same_v<OutDataType, BF16>)
+                { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(op_ptrs); }
+            }
+        }
...
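Besides the added 3-D lists, note that the 2-D function names were flipped from gnhwc_..._gnhwk to gnhwk_..._gnhwc: for backward data the registered instances read the output-gradient layout (GNHWK/NHWGK) first and write the input gradient (GNHWC/NHWGC) last, so the new names follow the template-argument order of the instances they register. A hedged sketch of pulling the new 3-D NDHWGK instances out of the factory (alias qualification assumed, as in the earlier sketch):

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
        3, NDHWGK, GKZYXC, Empty_Tuple, NDHWGC,
        F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // Filled by add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances above.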
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp

@@ -91,6 +91,42 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
         PassThrough, PassThrough>>>& instances);

+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, BF16, F32, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F16, F16, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, NHWGC, GKYXC, NHWGK, F32, F32, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
 // conv3d backward weight
 void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
     std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
...

@@ -128,6 +164,42 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
         PassThrough, PassThrough>>>& instances);

+void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, BF16, F32, BF16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F16, F16, F16,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
+    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, NDHWGC, GKZYXC, NDHWGK, F32, F32, F32,
+        PassThrough, PassThrough, PassThrough>>>& instances);
+
 template <ck::index_t NumDimSpatial,
           typename InLayout,
           typename WeiLayout,
...

@@ -162,66 +234,126 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
-                     is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
-        {
-            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
-            { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
-            { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
-            { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(op_ptrs); }
-        }
-        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
-                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
-        {
-            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
-            { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
-            { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
-            { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(op_ptrs); }
-        }
-        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
-                          is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
-        {
-            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
-            { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
-            { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); }
-            else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
-            { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(op_ptrs); }
-        }
+        if constexpr(NumDimSpatial == 1)
+        {
+            if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
+            {
+                if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
+                { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
+                { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
+                { add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(op_ptrs); }
+            }
+        }
+        else if constexpr(NumDimSpatial == 2)
+        {
+            if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
+            {
+                if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(op_ptrs); }
+            }
+            else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
+            {
+                if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
+                { add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(op_ptrs); }
+            }
+        }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
+            {
+                if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(op_ptrs); }
+            }
+            else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
+            {
+                if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, float>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && is_same_v<OutDataType, half_t>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); }
+                else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && is_same_v<OutDataType, ck::bhalf_t>)
+                { add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(op_ptrs); }
+            }
+        }
...
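The new NHWGC/NDHWGC backward-weight lists include a bf16_f32_bf16 flavor: BF16 input and output gradients with an F32 weight-gradient tensor, matching the DeviceGroupedConvBwdWeight<..., BF16, F32, BF16, ...> instance type declared above. A hedged sketch of selecting that 2-D flavor through the factory (alias qualification assumed, as in the earlier sketches):

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
        2, NHWGC, GKYXC, NHWGK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    // Filled by add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances above.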
library/src/tensor_operation_instance/gpu/CMakeLists.txt

@@ -12,9 +12,42 @@ set(CK_DEVICE_INSTANCES)
 FOREACH(subdir_path ${dir_list})
     set(target_dir)
     IF(IS_DIRECTORY "${subdir_path}")
-        get_filename_component(target_dir ${subdir_path} NAME)
-        add_subdirectory(${target_dir})
-        list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
+        set(cmake_instance)
+        file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
+        set(add_inst 0)
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\"" AND DTYPES MATCHES "fp8")
+            #message("fp8 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
+            #message("fp16 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
+            #message("fp32 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
+            #message("fp64 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
+            #message("bf16 instance found!")
+            set(add_inst 1)
+        endif()
+        if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
+            #message("int8 instance found!")
+            set(add_inst 1)
+        endif()
+        if(NOT "${cmake_instance}" MATCHES "DTYPES")
+            #message("instance should be built for all types!")
+            set(add_inst 1)
+        endif()
+        if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
+            get_filename_component(target_dir ${subdir_path} NAME)
+            add_subdirectory(${target_dir})
+            list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
+        endif()
     ENDIF()
 ENDFOREACH()
...
library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt

-add_instance_library(
-    device_batched_gemm_multi_d_instance
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
-    device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
-)
+set(BATCHED_GEMM_MULTID_INSTANCES)
+if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
+endif()
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
+    list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
+endif()
+add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt

-add_instance_library(
-    device_conv2d_bwd_data_instance
-    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
-    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
-    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
-    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
-    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
-    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
-    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp
-)
+set(CONV2D_BWD_DATA_INSTANCES)
+if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
+endif()
+if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
+endif()
+if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
+endif()
+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
+    list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
+endif()
+add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp

@@ -11,7 +11,7 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

+#ifdef __int8__
 namespace ck {
 namespace tensor_operation {
 namespace device {
...

@@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
+#endif
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
f0fd0263
add_instance_library
(
device_gemm_instance
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
set
(
GEMM_INSTANCES
)
if
(
DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
)
list
(
APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
)
endif
()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
    list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
endif()
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
)
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
set(ENABLE_PIPELINE_V2_OPT OFF)
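The DTYPES checks above only decide which instance sources get appended to GEMM_INSTANCES and handed to add_instance_library; the int8 sources in the hunks below additionally gain an #ifdef __int8__ guard, so a translation unit that is still compiled collapses to an empty object when int8 support is not enabled. A minimal, self-contained C++ sketch of that guard-plus-registration pattern follows; the names DeviceOpStub and add_int8_gemm_instances are hypothetical stand-ins, not the library's real types or API.

#include <memory>
#include <vector>

#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {

// Hypothetical stand-in for a concrete int8 device GEMM instance type.
struct DeviceOpStub
{
};

// Registration function in the style of add_device_gemm_dl_i8_i8_i8_*_instances:
// it appends concrete instances to a caller-owned container.
void add_int8_gemm_instances(std::vector<std::unique_ptr<DeviceOpStub>>& instances)
{
    instances.push_back(std::make_unique<DeviceOpStub>());
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif // __int8__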
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
...
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
...
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
...
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
...
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
...
@@ -8,7 +8,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#ifdef __int8__
namespace ck {
namespace tensor_operation {
namespace device {
...
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
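For completeness, a hypothetical caller-side sketch under the same assumptions (DeviceOpStub and add_int8_gemm_instances are invented for illustration, mirroring the sketch after the CMake fragment, and are not the library's API): when the int8 translation units above are compiled out, repeating the #ifdef __int8__ guard at the call site keeps the consumer building and simply leaves the instance list empty.

#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-in for a device GEMM instance type.
struct DeviceOpStub
{
};

#ifdef __int8__
// Hypothetical registration function; only defined when int8 is built.
void add_int8_gemm_instances(std::vector<std::unique_ptr<DeviceOpStub>>& instances)
{
    instances.push_back(std::make_unique<DeviceOpStub>());
}
#endif

int main()
{
    std::vector<std::unique_ptr<DeviceOpStub>> instances;
#ifdef __int8__
    add_int8_gemm_instances(instances); // skipped when int8 was compiled out
#endif
    std::printf("collected %zu int8 GEMM instances\n", instances.size());
    return 0;
}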