Commit 615e1d3e authored by Chao Liu's avatar Chao Liu
Browse files

update conv 1d and 3d instance

parent 11edd0f0
...@@ -17,6 +17,35 @@ namespace tensor_operation { ...@@ -17,6 +17,35 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// conv1d_fwd
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<1,
NWC,
KXC,
NWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv2d_fwd
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -90,6 +119,55 @@ void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( ...@@ -90,6 +119,55 @@ void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
// conv3d_fwd
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<3,
NDHWC,
KZYXC,
NDHWK,
int8_t,
int8_t,
int8_t,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <ck::index_t NumDimSpatial, template <ck::index_t NumDimSpatial,
typename InLayout, typename InLayout,
typename WeiLayout, typename WeiLayout,
...@@ -124,7 +202,32 @@ struct DeviceOperationInstanceFactory< ...@@ -124,7 +202,32 @@ struct DeviceOperationInstanceFactory<
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> && if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> && is_same_v<WeiLayout, KXC> &&
is_same_v<OutLayout, NWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>) is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
...@@ -153,6 +256,31 @@ struct DeviceOperationInstanceFactory< ...@@ -153,6 +256,31 @@ struct DeviceOperationInstanceFactory<
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs); add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
}
return op_ptrs; return op_ptrs;
} }
......
...@@ -37,9 +37,9 @@ target_link_libraries(ckProfiler PRIVATE utility) ...@@ -37,9 +37,9 @@ target_link_libraries(ckProfiler PRIVATE utility)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) #target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) #target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) #target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) #target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) #target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment