Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5683ea4e
Commit
5683ea4e
authored
Aug 08, 2023
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
f0831350
dddc2115
Changes
139
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
230 additions
and
133 deletions
+230
-133
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+8
-3
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
...ry/tensor_operation_instance/gpu/contraction_bilinear.hpp
+8
-6
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
...brary/tensor_operation_instance/gpu/contraction_scale.hpp
+8
-6
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
...nsor_operation_instance/gpu/convolution_backward_data.hpp
+60
-29
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
...ary/tensor_operation_instance/gpu/convolution_forward.hpp
+18
-8
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp
...nsor_operation_instance/gpu/elementwise_normalization.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
...include/ck/library/tensor_operation_instance/gpu/gemm.hpp
+4
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
...or_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
...k/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
...k/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp
...ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
...nv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
+42
-42
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
...ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp
...k/library/tensor_operation_instance/gpu/normalization.hpp
+10
-7
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
...e/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
+10
-7
library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp
...e/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp
+10
-7
library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp
...operation_instance/gpu/quantization/gemm_quantization.hpp
+12
-3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
...uped_convolution_bias_forward_perchannel_quantization.hpp
+10
-3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
...rouped_convolution_bias_forward_perlayer_quantization.hpp
+10
-3
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
...n/grouped_convolution_forward_perchannel_quantization.hpp
+8
-3
No files found.
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
...
@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
#endif
#ifdef __bf16__
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
...
@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
instances
);
#endif
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
B0DataType
,
typename
B0DataType
,
...
@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
{
...
@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#ifdef __bf16__
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
)
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
)
{
{
...
@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp32__
// float
// float
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
...
@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
>>>&
instances
);
#endif
#ifdef __fp64__
// double
// double
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
void
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
...
@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Bilinear
>>>&
instances
);
Bilinear
>>>&
instances
);
#endif
// Contraction + Bilinear
// Contraction + Bilinear
template
<
index_t
NumDimM
,
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimN
,
...
@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp32__
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
DDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
is_same_v
<
DDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
{
{
...
@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#ifdef __fp64__
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
DDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
is_same_v
<
DDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
{
{
...
@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp32__
// float
// float
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
...
@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
>>>&
instances
);
#endif
#ifdef __fp64__
// double
// double
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
void
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance
(
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceContractionMultipleD
<
2
,
...
@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
...
@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
>>>&
instances
);
Scale
>>>&
instances
);
#endif
// Contraction + Scale
// Contraction + Scale
template
<
index_t
NumDimM
,
template
<
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimN
,
...
@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp32__
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
float
>
&&
is_same_v
<
BDataType
,
float
>
&&
is_same_v
<
EDataType
,
float
>
)
is_same_v
<
EDataType
,
float
>
)
{
{
...
@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#ifdef __fp64__
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
if
constexpr
(
is_same_v
<
ADataType
,
double
>
&&
is_same_v
<
BDataType
,
double
>
&&
is_same_v
<
EDataType
,
double
>
)
is_same_v
<
EDataType
,
double
>
)
{
{
...
@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
...
@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __bf16__
// conv1d backward data
// conv1d backward data
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
...
@@ -29,16 +29,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
...
@@ -29,16 +29,19 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp16__
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __fp32__
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceConvBwdData
<
1
,
NWC
,
KXC
,
NWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __int8__
#ifdef __int8__
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
void
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
1
,
...
@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
...
@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef __bf16__
// conv2d backward data
// conv2d backward data
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
...
@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp16__
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
NHWC
,
...
@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
...
@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp32__
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
NHWC
,
...
@@ -88,6 +94,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
...
@@ -88,6 +94,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __int8__
#ifdef __int8__
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
void
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
...
@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#ifdef DL_KERNELS
#ifdef __fp16__
// conv2d dl
// conv2d dl
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
...
@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp32__
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
NHWC
,
NHWC
,
...
@@ -125,6 +135,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
...
@@ -125,6 +135,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __int8__
#ifdef __int8__
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
void
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
2
,
...
@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
...
@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#endif
#endif
#ifdef __bf16__
// conv3d backward data
// conv3d backward data
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
...
@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
...
@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp16__
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
NDHWC
,
...
@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
...
@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __fp32__
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
NDHWC
,
NDHWC
,
...
@@ -174,6 +189,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
...
@@ -174,6 +189,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
#ifdef __int8__
#ifdef __int8__
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
void
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvBwdData
<
3
,
...
@@ -229,19 +245,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -229,19 +245,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
#ifdef __fp16__
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
#endif
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
#ifdef __bf16__
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances
(
op_ptrs
);
}
}
#endif
#ifdef __int8__
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
op_ptrs
);
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances
(
op_ptrs
);
...
@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v
<
OutDataType
,
float
>
)
is_same_v
<
OutDataType
,
float
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
#endif
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
#ifdef __fp16__
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
#endif
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
#endif
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
#ifdef __bf16__
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
}
#endif
#ifdef __int8__
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
#endif
}
}
#endif
#endif
}
}
...
@@ -286,19 +314,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
...
@@ -286,19 +314,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
#ifdef __fp16__
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
#endif
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
#ifdef __bf16__
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances
(
op_ptrs
);
}
}
#endif
#ifdef __int8__
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
op_ptrs
);
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances
(
op_ptrs
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/convolution_forward.hpp
View file @
5683ea4e
...
@@ -18,11 +18,17 @@ namespace device {
...
@@ -18,11 +18,17 @@ namespace device {
namespace
instance
{
namespace
instance
{
// conv2d forward
// conv2d forward
#ifdef __fp16__
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef __bf16__
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
NHWC
,
...
@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
...
@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
#ifdef __fp32__
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceConvFwd
<
2
,
NHWC
,
KYXC
,
NHWK
,
F32
,
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
instances
);
#endif
#ifdef __int8__
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceConvFwd
<
2
,
NHWC
,
NHWC
,
...
@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
...
@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
PassThrough
>>>&
instances
);
#endif
template
<
ck
::
index_t
NumDimSpatial
,
template
<
ck
::
index_t
NumDimSpatial
,
typename
InLayout
,
typename
InLayout
,
...
@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory<
...
@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory<
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances
(
op_ptrs
);
}
}
#ifdef __fp16__
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
half_t
>
&&
is_same_v
<
WeiDataType
,
half_t
>
&&
is_same_v
<
OutDataType
,
half_t
>
)
is_same_v
<
OutDataType
,
half_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances
(
op_ptrs
);
}
}
#endif
#ifdef __bf16__
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
WeiDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
OutDataType
,
ck
::
bhalf_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances
(
op_ptrs
);
}
}
#endif
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
InDataType
,
int8_t
>
&&
is_same_v
<
WeiDataType
,
int8_t
>
&&
is_same_v
<
OutDataType
,
int8_t
>
)
is_same_v
<
OutDataType
,
int8_t
>
)
{
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances
(
op_ptrs
);
}
}
#endif
}
}
return
op_ptrs
;
return
op_ptrs
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp
View file @
5683ea4e
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
...
@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
View file @
5683ea4e
...
@@ -343,6 +343,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -343,6 +343,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
#ifdef __fp16__
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
is_same_v
<
CDataType
,
half_t
>
)
{
{
...
@@ -388,6 +389,8 @@ struct DeviceOperationInstanceFactory<
...
@@ -388,6 +389,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
#endif
#ifdef __bf16__
else
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BDataType
,
ck
::
bhalf_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
BDataType
,
ck
::
bhalf_t
>
&&
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
is_same_v
<
CDataType
,
ck
::
bhalf_t
>
)
{
{
...
@@ -412,6 +415,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -412,6 +415,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
#endif
#ifdef __int8__
#ifdef __int8__
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
else
if
constexpr
(
is_same_v
<
ADataType
,
int8_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
int8_t
>
)
is_same_v
<
CDataType
,
int8_t
>
)
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
View file @
5683ea4e
...
@@ -9,7 +9,7 @@
...
@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp
View file @
5683ea4e
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
View file @
5683ea4e
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -136,3 +136,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
void
add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances
(
void
add_device_gemm_xdl_streamk_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmStreamK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
DeviceGemmStreamK
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
...
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
...
@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmSt
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp
View file @
5683ea4e
This diff is collapsed.
Click to expand it.
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
View file @
5683ea4e
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __fp16__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -192,3 +192,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/normalization.hpp
View file @
5683ea4e
...
@@ -16,7 +16,7 @@ namespace ck {
...
@@ -16,7 +16,7 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef __fp16__
// FP16
// FP16
void
add_device_normalization_rank_2_1_f16_instances
(
void
add_device_normalization_rank_2_1_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
2
,
1
>>>&
);
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
2
,
1
>>>&
);
...
@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
...
@@ -26,7 +26,8 @@ void add_device_normalization_rank_4_3_f16_instances(
void
add_device_normalization_rank_5_3_f16_instances
(
void
add_device_normalization_rank_5_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
5
,
3
>>>&
);
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F16
,
F16
,
F16
,
F32
,
F16
,
PassThrough
,
5
,
3
>>>&
);
#endif
#ifdef __fp32__
// FP32
// FP32
void
add_device_normalization_rank_2_1_f32_instances
(
void
add_device_normalization_rank_2_1_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
2
,
1
>>>&
);
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
2
,
1
>>>&
);
...
@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
...
@@ -36,7 +37,7 @@ void add_device_normalization_rank_4_3_f32_instances(
void
add_device_normalization_rank_5_3_f32_instances
(
void
add_device_normalization_rank_5_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
5
,
3
>>>&
);
std
::
vector
<
std
::
unique_ptr
<
DeviceNormalization
<
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
5
,
3
>>>&
);
#endif
template
<
typename
XDataType
,
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
...
@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
...
@@ -65,7 +66,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
GammaDataType
,
F16
>
&&
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
GammaDataType
,
F16
>
&&
is_same_v
<
BetaDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
)
is_same_v
<
BetaDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
)
{
{
...
@@ -82,7 +83,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
...
@@ -82,7 +83,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f16_instances
(
op_ptrs
);
add_device_normalization_rank_5_3_f16_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
GammaDataType
,
F32
>
&&
#endif
#ifdef __fp32__
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
GammaDataType
,
F32
>
&&
is_same_v
<
BetaDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
)
is_same_v
<
BetaDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
)
{
{
if
constexpr
(
Rank
==
2
&&
NumReduceDim
==
1
)
if
constexpr
(
Rank
==
2
&&
NumReduceDim
==
1
)
...
@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
...
@@ -98,7 +101,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceNormal
add_device_normalization_rank_5_3_f32_instances
(
op_ptrs
);
add_device_normalization_rank_5_3_f32_instances
(
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp
View file @
5683ea4e
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 2;
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 2;
static
constexpr
auto
MaxOp
=
ck
::
ReduceTensorOp
::
MAX
;
static
constexpr
auto
MaxOp
=
ck
::
ReduceTensorOp
::
MAX
;
static
constexpr
auto
AvgOp
=
ck
::
ReduceTensorOp
::
AVG
;
static
constexpr
auto
AvgOp
=
ck
::
ReduceTensorOp
::
AVG
;
#ifdef __fp16__
// FP16
// FP16
void
add_device_pool2d_fwd_nhwc_f16_instances
(
void
add_device_pool2d_fwd_nhwc_f16_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -36,7 +36,8 @@ void add_device_pool2d_fwd_nhwc_f16_instances(
...
@@ -36,7 +36,8 @@ void add_device_pool2d_fwd_nhwc_f16_instances(
void
add_device_pool2d_fwd_nhwc_index_f16_instances
(
void
add_device_pool2d_fwd_nhwc_index_f16_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
MaxOp
,
true
>>>&
);
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
MaxOp
,
true
>>>&
);
#endif
#ifdef __fp32__
// FP32
// FP32
void
add_device_pool2d_fwd_nhwc_f32_instances
(
void
add_device_pool2d_fwd_nhwc_f32_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -50,7 +51,7 @@ void add_device_pool2d_fwd_nhwc_f32_instances(
...
@@ -50,7 +51,7 @@ void add_device_pool2d_fwd_nhwc_f32_instances(
void
add_device_pool2d_fwd_nhwc_index_f32_instances
(
void
add_device_pool2d_fwd_nhwc_index_f32_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
MaxOp
,
true
>>>&
);
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
MaxOp
,
true
>>>&
);
#endif
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
...
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
is_same_v
<
IndexDataType
,
I32
>
)
{
{
...
@@ -88,7 +89,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -88,7 +89,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f16_instances
(
op_ptrs
);
add_device_pool2d_fwd_nhwc_f16_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
F32
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
#endif
#ifdef __fp32__
if
constexpr
(
is_same_v
<
InDataType
,
F32
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
is_same_v
<
IndexDataType
,
I32
>
)
{
{
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
...
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool2d_fwd_nhwc_f32_instances
(
op_ptrs
);
add_device_pool2d_fwd_nhwc_f32_instances
(
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp
View file @
5683ea4e
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
...
@@ -22,7 +22,7 @@ static constexpr auto WindowRank = 3;
static
constexpr
auto
MaxOp
=
ck
::
ReduceTensorOp
::
MAX
;
static
constexpr
auto
MaxOp
=
ck
::
ReduceTensorOp
::
MAX
;
static
constexpr
auto
AvgOp
=
ck
::
ReduceTensorOp
::
AVG
;
static
constexpr
auto
AvgOp
=
ck
::
ReduceTensorOp
::
AVG
;
#ifdef __fp16__
// FP16
// FP16
void
add_device_pool3d_fwd_ndhwc_f16_instances
(
void
add_device_pool3d_fwd_ndhwc_f16_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -36,7 +36,8 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
...
@@ -36,7 +36,8 @@ void add_device_pool3d_fwd_ndhwc_f16_instances(
void
add_device_pool3d_fwd_ndhwc_index_f16_instances
(
void
add_device_pool3d_fwd_ndhwc_index_f16_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
MaxOp
,
true
>>>&
);
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F16
,
F16
,
I32
,
MaxOp
,
true
>>>&
);
#endif
#ifdef __fp32__
// FP32
// FP32
void
add_device_pool3d_fwd_ndhwc_f32_instances
(
void
add_device_pool3d_fwd_ndhwc_f32_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -50,7 +51,7 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
...
@@ -50,7 +51,7 @@ void add_device_pool3d_fwd_ndhwc_f32_instances(
void
add_device_pool3d_fwd_ndhwc_index_f32_instances
(
void
add_device_pool3d_fwd_ndhwc_index_f32_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
MaxOp
,
true
>>>&
);
std
::
unique_ptr
<
DevicePoolFwd
<
InOutRank
,
WindowRank
,
F32
,
F32
,
I32
,
MaxOp
,
true
>>>&
);
#endif
template
<
typename
InDataType
,
template
<
typename
InDataType
,
typename
OutDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
IndexDataType
,
...
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -75,7 +76,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef __fp16__
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
&&
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
is_same_v
<
IndexDataType
,
I32
>
)
{
{
...
@@ -88,7 +89,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -88,7 +89,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f16_instances
(
op_ptrs
);
add_device_pool3d_fwd_ndhwc_f16_instances
(
op_ptrs
);
}
}
}
}
else
if
constexpr
(
is_same_v
<
InDataType
,
F32
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
#endif
#ifdef __fp32__
if
constexpr
(
is_same_v
<
InDataType
,
F32
>
&&
is_same_v
<
OutDataType
,
F32
>
&&
is_same_v
<
IndexDataType
,
I32
>
)
is_same_v
<
IndexDataType
,
I32
>
)
{
{
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
if
constexpr
(
OutputIndex
&&
ReduceOpId
==
MaxOp
)
...
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
...
@@ -100,7 +103,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFw
add_device_pool3d_fwd_ndhwc_f32_instances
(
op_ptrs
);
add_device_pool3d_fwd_ndhwc_f32_instances
(
op_ptrs
);
}
}
}
}
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp
View file @
5683ea4e
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef DL_KERNELS
// Layout(A, B, C) = [Col, Row, Row]
// Layout(A, B, C) = [Col, Row, Row]
void
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
void
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
...
@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
...
@@ -76,7 +76,7 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
PassThrough
,
PassThrough
,
Activation_Mul_Clamp
<
PassThrough
>>>>&
Activation_Mul_Clamp
<
PassThrough
>>>>&
instances
);
instances
);
#endif
// Layout(A, B, C) = [Col, Row, Row]
// Layout(A, B, C) = [Col, Row, Row]
void
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
void
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
...
@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -181,7 +181,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
op_ptrs
);
}
}
}
}
...
@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -190,7 +192,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
op_ptrs
);
}
}
}
}
...
@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -199,7 +203,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
op_ptrs
);
}
}
}
}
...
@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -208,7 +214,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
#endif
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
op_ptrs
);
}
}
}
}
...
@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
...
@@ -222,3 +230,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
\ No newline at end of file
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp
View file @
5683ea4e
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances
(
void
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -64,7 +64,7 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
...
@@ -64,7 +64,7 @@ void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(
PassThrough
,
PassThrough
,
Add_Mul2_Activation_Mul_Clamp
<
TanH
>>>>&
Add_Mul2_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
instances
);
#endif
void
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
void
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
@@ -163,12 +163,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -163,12 +163,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_perchannel_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
}
}
...
@@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -229,7 +233,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
}
}
...
@@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -243,3 +249,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp
View file @
5683ea4e
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances
(
void
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -63,7 +63,7 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
...
@@ -63,7 +63,7 @@ void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(
PassThrough
,
PassThrough
,
Add_Mul_Activation_Mul_Clamp
<
TanH
>>>>&
Add_Mul_Activation_Mul_Clamp
<
TanH
>>>>&
instances
);
instances
);
#endif
void
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
void
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
@@ -161,12 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -161,12 +161,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_perlayer_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances
(
op_ptrs
);
}
}
}
}
...
@@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -227,7 +231,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
if
constexpr
(
is_same_v
<
Activation
,
TanH
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances
(
op_ptrs
);
}
}
}
}
...
@@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -241,3 +247,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp
View file @
5683ea4e
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef __int8__
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
#ifdef DL_KERNELS
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
void
add_device_conv2d_dl_perchannel_quantization_int8_instances
(
void
add_device_conv2d_dl_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
...
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
...
@@ -47,7 +47,7 @@ void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(
PassThrough
,
PassThrough
,
Activation_Mul2_Clamp
<
Relu
>>>>&
Activation_Mul2_Clamp
<
Relu
>>>>&
instances
);
instances
);
#endif
void
add_device_conv2d_xdl_perchannel_quantization_int8_instances
(
void
add_device_conv2d_xdl_perchannel_quantization_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleD
<
2
,
NHWGC
,
NHWGC
,
...
@@ -128,12 +128,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -128,12 +128,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
{
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
if
constexpr
(
is_same_v
<
Activation
,
PassThrough
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_perchannel_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
else
if
constexpr
(
is_same_v
<
Activation
,
Relu
>
)
{
{
#ifdef DL_KERNELS
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_dl_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
#endif
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances
(
op_ptrs
);
}
}
}
}
...
@@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -147,3 +151,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment